|
| 1 | +import type { |
| 2 | + BirpcReturn, |
| 3 | + RpcDefinitionsToFunctions, |
| 4 | + RpcDumpClientOptions, |
| 5 | + RpcDumpCollectionOptions, |
| 6 | + RpcDumpDefinition, |
| 7 | + RpcDumpStore, |
| 8 | + RpcFunctionDefinitionAny, |
| 9 | +} from '../types' |
| 10 | +import { hash } from 'devframe/utils/hash' |
| 11 | +import pLimit from 'p-limit' |
| 12 | +import { logger } from '../diagnostics' |
| 13 | +import { validateDefinitions } from '../validation' |
| 14 | +import { reviveDumpError, serializeDumpError } from './error' |
| 15 | + |
| 16 | +function getDumpRecordKey(functionName: string, args: any[]): string { |
| 17 | + const argsHash = hash(args) |
| 18 | + return `${functionName}---${argsHash}` |
| 19 | +} |
| 20 | + |
| 21 | +function getDumpFallbackKey(functionName: string): string { |
| 22 | + return `${functionName}---fallback` |
| 23 | +} |
| 24 | + |
| 25 | +async function resolveGetter<T>(valueOrGetter: T | (() => Promise<T>)): Promise<T> { |
| 26 | + return typeof valueOrGetter === 'function' |
| 27 | + ? await (valueOrGetter as () => Promise<T>)() |
| 28 | + : valueOrGetter |
| 29 | +} |
| 30 | + |
| 31 | +/** |
| 32 | + * Collects pre-computed dumps by executing functions with their defined input combinations. |
| 33 | + * Static functions without dump config automatically get `{ inputs: [[]] }`. |
| 34 | + * |
| 35 | + * @example |
| 36 | + * ```ts |
| 37 | + * const store = await dumpFunctions([greet], context, { concurrency: 10 }) |
| 38 | + * ``` |
| 39 | + */ |
| 40 | +export async function dumpFunctions< |
| 41 | + T extends readonly RpcFunctionDefinitionAny[], |
| 42 | +>( |
| 43 | + definitions: T, |
| 44 | + context?: any, |
| 45 | + options: RpcDumpCollectionOptions = {}, |
| 46 | +): Promise<RpcDumpStore<RpcDefinitionsToFunctions<T>>> { |
| 47 | + validateDefinitions(definitions) |
| 48 | + const concurrency = options.concurrency === true |
| 49 | + ? 5 |
| 50 | + : options.concurrency === false || options.concurrency == null |
| 51 | + ? 1 |
| 52 | + : options.concurrency |
| 53 | + |
| 54 | + const store: RpcDumpStore = { |
| 55 | + definitions: {}, |
| 56 | + records: {}, |
| 57 | + } |
| 58 | + |
| 59 | + // #region Definition resolution |
| 60 | + interface TaskResolution { |
| 61 | + handler: (...args: any[]) => any |
| 62 | + dump: RpcDumpDefinition |
| 63 | + definition: RpcFunctionDefinitionAny |
| 64 | + } |
| 65 | + |
| 66 | + const tasksResolutions: (() => Promise<undefined | TaskResolution>)[] = definitions.map(definition => async () => { |
| 67 | + if (definition.type === 'event' || definition.type === 'action') { |
| 68 | + return undefined |
| 69 | + } |
| 70 | + |
| 71 | + // Fresh setup results for each context to avoid caching issues |
| 72 | + const setupResult = definition.setup |
| 73 | + ? await Promise.resolve(definition.setup(context)) |
| 74 | + : {} |
| 75 | + |
| 76 | + const handler = setupResult.handler || definition.handler |
| 77 | + if (!handler) { |
| 78 | + throw logger.DF0024({ name: definition.name }).throw() |
| 79 | + } |
| 80 | + |
| 81 | + let dump = setupResult.dump ?? definition.dump |
| 82 | + if (!dump && definition.type === 'static') { |
| 83 | + dump = { inputs: [[]] } |
| 84 | + } |
| 85 | + if (!dump && definition.snapshot) { |
| 86 | + // Sugar: run the handler once with no args, store the result as |
| 87 | + // both the no-args record and the fallback. Any client call then |
| 88 | + // resolves to the same snapshot — matching NMI's "getPayload() |
| 89 | + // always returns the baked dump" shape. |
| 90 | + dump = async (_ctx, h) => { |
| 91 | + const output = await Promise.resolve(h(...([] as unknown as any[]))) |
| 92 | + return { |
| 93 | + records: [{ inputs: [] as any, output }], |
| 94 | + fallback: output, |
| 95 | + } |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + if (!dump) { |
| 100 | + return undefined |
| 101 | + } |
| 102 | + |
| 103 | + if (typeof dump === 'function') { |
| 104 | + dump = await Promise.resolve(dump(context, handler)) |
| 105 | + } |
| 106 | + |
| 107 | + // Only add to definitions if it has a dump |
| 108 | + store.definitions[definition.name] = { |
| 109 | + name: definition.name, |
| 110 | + type: definition.type, |
| 111 | + } |
| 112 | + |
| 113 | + return { |
| 114 | + handler, |
| 115 | + dump, |
| 116 | + definition, |
| 117 | + } |
| 118 | + }) |
| 119 | + |
| 120 | + let functionsToDump: TaskResolution[] = [] |
| 121 | + if (concurrency <= 1) { |
| 122 | + for (const task of tasksResolutions) { |
| 123 | + const resolution = await task() |
| 124 | + if (resolution) { |
| 125 | + functionsToDump.push(resolution) |
| 126 | + } |
| 127 | + } |
| 128 | + } |
| 129 | + else { |
| 130 | + const limit = pLimit(concurrency) |
| 131 | + functionsToDump = (await Promise.all(tasksResolutions.map(task => limit(task)))).filter(x => !!x) |
| 132 | + } |
| 133 | + // #endregion |
| 134 | + |
| 135 | + // #region Dump execution |
| 136 | + const dumpTasks: Array<() => Promise<void>> = [] |
| 137 | + for (const { definition, handler, dump } of functionsToDump) { |
| 138 | + const { inputs, records, fallback } = dump |
| 139 | + |
| 140 | + // Add pre-defined records |
| 141 | + if (records) { |
| 142 | + for (const record of records) { |
| 143 | + const recordKey = getDumpRecordKey(definition.name, record.inputs) |
| 144 | + store.records[recordKey] = record |
| 145 | + } |
| 146 | + } |
| 147 | + |
| 148 | + // Add fallback record |
| 149 | + if ('fallback' in dump) { |
| 150 | + const fallbackKey = getDumpFallbackKey(definition.name) |
| 151 | + store.records[fallbackKey] = { |
| 152 | + inputs: [], |
| 153 | + output: fallback, |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + // Add input records execution tasks |
| 158 | + if (inputs) { |
| 159 | + for (const input of inputs) { |
| 160 | + dumpTasks.push(async () => { |
| 161 | + const recordKey = getDumpRecordKey(definition.name, input) |
| 162 | + |
| 163 | + try { |
| 164 | + const output = await Promise.resolve(handler(...input)) |
| 165 | + store.records[recordKey] = { |
| 166 | + inputs: input, |
| 167 | + output, |
| 168 | + } |
| 169 | + } |
| 170 | + catch (error: unknown) { |
| 171 | + store.records[recordKey] = { |
| 172 | + inputs: input, |
| 173 | + error: serializeDumpError(error), |
| 174 | + } |
| 175 | + } |
| 176 | + }) |
| 177 | + } |
| 178 | + } |
| 179 | + } |
| 180 | + |
| 181 | + if (concurrency <= 1) { |
| 182 | + for (const task of dumpTasks) { |
| 183 | + await task() |
| 184 | + } |
| 185 | + } |
| 186 | + else { |
| 187 | + const limit = pLimit(concurrency) |
| 188 | + await Promise.all(dumpTasks.map(task => limit(task))) |
| 189 | + } |
| 190 | + // #endregion |
| 191 | + |
| 192 | + return store |
| 193 | +} |
| 194 | + |
| 195 | +/** |
| 196 | + * Creates a client that serves pre-computed results from a dump store. |
| 197 | + * Uses argument hashing to match calls to stored records. |
| 198 | + * |
| 199 | + * @example |
| 200 | + * ```ts |
| 201 | + * const client = createClientFromDump(store) |
| 202 | + * await client.greet('Alice') |
| 203 | + * ``` |
| 204 | + */ |
| 205 | +export function createClientFromDump<T extends Record<string, any>>( |
| 206 | + store: RpcDumpStore<T>, |
| 207 | + options: RpcDumpClientOptions = {}, |
| 208 | +): BirpcReturn<T> { |
| 209 | + const { onMiss } = options |
| 210 | + |
| 211 | + const client = new Proxy({} as T, { |
| 212 | + get(_, functionName: string) { |
| 213 | + if (!(functionName in store.definitions)) { |
| 214 | + throw logger.DF0025({ name: functionName }).throw() |
| 215 | + } |
| 216 | + |
| 217 | + return async (...args: any[]) => { |
| 218 | + const recordKey = getDumpRecordKey(functionName, args) |
| 219 | + |
| 220 | + const recordOrGetter = store.records[recordKey] |
| 221 | + |
| 222 | + if (recordOrGetter) { |
| 223 | + const record = await resolveGetter(recordOrGetter) |
| 224 | + |
| 225 | + if (record.error) { |
| 226 | + throw reviveDumpError(record.error) |
| 227 | + } |
| 228 | + |
| 229 | + if (typeof record.output === 'function') { |
| 230 | + return await record.output() |
| 231 | + } |
| 232 | + |
| 233 | + return record.output |
| 234 | + } |
| 235 | + |
| 236 | + onMiss?.(functionName, args) |
| 237 | + |
| 238 | + const fallbackKey = getDumpFallbackKey(functionName) |
| 239 | + if (fallbackKey in store.records) { |
| 240 | + const fallbackOrGetter = store.records[fallbackKey] |
| 241 | + |
| 242 | + const fallbackRecord = await resolveGetter(fallbackOrGetter) |
| 243 | + |
| 244 | + if (fallbackRecord && typeof fallbackRecord.output === 'function') { |
| 245 | + return await fallbackRecord.output() |
| 246 | + } |
| 247 | + if (fallbackRecord) |
| 248 | + return fallbackRecord.output |
| 249 | + } |
| 250 | + |
| 251 | + throw logger.DF0026({ name: functionName, args: JSON.stringify(args) }).throw() |
| 252 | + } |
| 253 | + }, |
| 254 | + has(_, functionName: string) { |
| 255 | + return functionName in store.definitions |
| 256 | + }, |
| 257 | + ownKeys() { |
| 258 | + return Object.keys(store.definitions) |
| 259 | + }, |
| 260 | + getOwnPropertyDescriptor(_, functionName: string) { |
| 261 | + return functionName in store.definitions |
| 262 | + ? { configurable: true, enumerable: true, value: undefined } |
| 263 | + : undefined |
| 264 | + }, |
| 265 | + }) |
| 266 | + |
| 267 | + return client as any as BirpcReturn<T> |
| 268 | +} |
| 269 | + |
| 270 | +/** |
| 271 | + * Filters function definitions to only those with dump definitions. |
| 272 | + * Note: Only checks the definition itself, not setup results. |
| 273 | + */ |
| 274 | +export function getDefinitionsWithDumps<T extends readonly RpcFunctionDefinitionAny[]>( |
| 275 | + definitions: T, |
| 276 | +): RpcFunctionDefinitionAny[] { |
| 277 | + return definitions.filter(def => def.dump !== undefined) |
| 278 | +} |
0 commit comments