diff --git a/.npmignore b/.npmignore new file mode 100644 index 00000000..e69de29b diff --git a/README.md b/README.md index 1f61bd4d..938a7e3b 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,11 @@ const replicate = new Replicate({ Run a model and await the result: ```js -const model = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"; -const input = { prompt: "a 19th century portrait of a raccoon gentleman wearing a suit" }; +const model = + "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"; +const input = { + prompt: "a 19th century portrait of a raccoon gentleman wearing a suit", +}; const output = await replicate.run(model, { input }); // ['https://replicate.delivery/pbxt/GtQb3Sgve42ZZyVnt8xjquFk9EX5LP0fF68NTIWlgBMUpguQA/out-0.png'] ``` @@ -53,7 +56,7 @@ Or wait for the prediction to finish: ```js prediction = await replicate.wait(prediction); -console.log(prediction.output) +console.log(prediction.output); // ['https://replicate.delivery/pbxt/RoaxeXqhL0xaYyLm6w3bpGwF5RaNBjADukfFnMbhOyeoWBdhA/out-0.png'] ``` @@ -280,12 +283,14 @@ which you can use in a for loop or iterate over manually. ```js // iterate over paginated results in a for loop -for await (const page of replicate.paginate(replicate.predictions.list)) { +for await (const page of replicate.paginate( + replicate.predictions.list.bind(replicate) +)) { /* do something with page of results */ } // iterate over paginated results one at a time -let paginator = replicate.paginate(replicate.predictions.list); +let paginator = replicate.paginate(replicate.predictions.list.bind(replicate)); const page1 = await paginator.next(); const page2 = await paginator.next(); // etc. diff --git a/index.d.ts b/index.d.ts deleted file mode 100644 index bc5be670..00000000 --- a/index.d.ts +++ /dev/null @@ -1,133 +0,0 @@ -declare module 'replicate' { - type Status = 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled'; - type WebhookEventType = 'start' | 'output' | 'logs' | 'completed'; - - interface Page { - previous?: string; - next?: string; - results: T[]; - } - - export interface Collection { - name: string; - slug: string; - description: string; - models: Model[]; - } - - export interface Model { - url: string; - owner: string; - name: string; - description: string; - visibility: 'public' | 'private'; - github_url: string; - paper_url: string; - license_url: string; - run_count: number; - cover_image_url: string; - default_example?: Prediction; - latest_version?: ModelVersion; - } - - export interface ModelVersion { - id: string; - created_at: string; - cog_version: string; - openapi_schema: object; - } - - export interface Prediction { - id: string; - status: Status; - version: string; - input: object; - output: any; - source: 'api' | 'web'; - error?: any; - logs?: string; - metrics?: { - predicti_time?: number; - } - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - created_at: string; - updated_at: string; - completed_at?: string; - } - - export default class Replicate { - constructor(options: { - auth: string; - userAgent?: string; - baseUrl?: string; - }); - - auth: string; - userAgent?: string; - baseUrl?: string; - fetch: Function; - - run( - identifier: `${string}/${string}:${string}`, - options: { - input: object; - wait?: boolean | { interval?: number; maxAttempts?: number }; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } - ): Promise; - request(route: string, parameters: any): Promise; - paginate(endpoint: () => Promise>): AsyncGenerator<[ T ]>; - wait( - prediction: Prediction, - options: { - interval?: number; - maxAttempts?: number; - } - ): Promise; - - collections: { - get(collection_slug: string): Promise; - }; - - models: { - get(model_owner: string, model_name: string): Promise; - versions: { - list(model_owner: string, model_name: string): Promise; - get( - model_owner: string, - model_name: string, - version_id: string - ): Promise; - }; - }; - - predictions: { - create(options: { - version: string; - input: object; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - }): Promise; - get(prediction_id: string): Promise; - list(): Promise>; - }; - - trainings: { - create( - model_owner: string, - model_name: string, - version_id: string, - options: { - destination: `${string}/${string}`; - input: object; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } - ): Promise; - get(options: TrainingsGetOptions): Promise; - cancel(options: TrainingsGetOptions): Promise; - }; - } -} diff --git a/index.js b/index.js deleted file mode 100644 index 6cb513fb..00000000 --- a/index.js +++ /dev/null @@ -1,234 +0,0 @@ -const collections = require('./lib/collections'); -const models = require('./lib/models'); -const predictions = require('./lib/predictions'); -const trainings = require('./lib/trainings'); -const packageJSON = require('./package.json'); - -/** - * Replicate API client library - * - * @see https://replicate.com/docs/reference/http - * @example - * // Create a new Replicate API client instance - * const Replicate = require("replicate"); - * const replicate = new Replicate({ - * // get your token from https://replicate.com/account - * auth: process.env.REPLICATE_API_TOKEN, - * userAgent: "my-app/1.2.3" - * }); - * - * // Run a model and await the result: - * const model = 'owner/model:version-id' - * const input = {text: 'Hello, world!'} - * const output = await replicate.run(model, { input }); - */ -class Replicate { - /** - * Create a new Replicate API client instance. - * - * @param {object} options - Configuration options for the client - * @param {string} options.auth - Required. API access token - * @param {string} options.userAgent - Identifier of your app - * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 - * @param {Function} [options.fetch] - Defaults to native fetch - */ - constructor(options) { - this.auth = options.auth; - this.userAgent = - options.userAgent || `replicate-javascript/${packageJSON.version}`; - this.baseUrl = options.baseUrl || 'https://api.replicate.com/v1'; - this.fetch = fetch; - - this.collections = { - get: collections.get.bind(this), - }; - - this.models = { - get: models.get.bind(this), - versions: { - list: models.versions.list.bind(this), - get: models.versions.get.bind(this), - }, - }; - - this.predictions = { - create: predictions.create.bind(this), - get: predictions.get.bind(this), - list: predictions.list.bind(this), - }; - - this.trainings = { - create: trainings.create.bind(this), - get: trainings.get.bind(this), - cancel: trainings.cancel.bind(this), - }; - } - - /** - * Run a model and wait for its output. - * - * @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}" - * @param {object} options - * @param {object} options.input - Required. An object with the model inputs - * @param {boolean|object} [options.wait] - Whether to wait for the prediction to finish. Defaults to false - * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 250 - * @param {number} [options.wait.maxAttempts] - Maximum number of polling attempts. Defaults to no limit - * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @throws {Error} If the prediction failed - * @returns {Promise} - Resolves with the output of running the model - */ - async run(identifier, options) { - const pattern = - /^(?[a-zA-Z0-9-_]+?)\/(?[a-zA-Z0-9-_]+?):(?[0-9a-fA-F]+)$/; - const match = identifier.match(pattern); - - if (!match) { - throw new Error( - 'Invalid version. It must be in the format "owner/name:version"' - ); - } - - const { version } = match.groups; - const prediction = await this.predictions.create({ - wait: true, - ...options, - version, - }); - - if (prediction.status === 'failed') { - throw new Error(`Prediction failed: ${prediction.error}`); - } - - return prediction.output; - } - - /** - * Make a request to the Replicate API. - * - * @param {string} route - REST API endpoint path - * @param {object} parameters - Request parameters - * @param {string} [parameters.method] - HTTP method. Defaults to GET - * @param {object} [parameters.params] - Query parameters - * @param {object} [parameters.data] - Body parameters - * @returns {Promise} - Resolves with the API response data - */ - async request(route, parameters) { - const { auth, baseUrl, userAgent } = this; - - const url = new URL( - route.startsWith('/') ? route.slice(1) : route, - baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/` - ); - - const { method = 'GET', params = {}, data } = parameters; - - Object.entries(params).forEach(([key, value]) => { - url.searchParams.append(key, value); - }); - - const headers = { - Authorization: `Token ${auth}`, - 'Content-Type': 'application/json', - 'User-Agent': userAgent, - }; - - const response = await this.fetch(url, { - method, - headers, - body: data ? JSON.stringify(data) : undefined, - }); - - if (!response.ok) { - throw new Error(`API request failed: ${response.statusText}`); - } - - return response.json(); - } - - /** - * Paginate through a list of results. - * - * @generator - * @example - * for await (const page of replicate.paginate(replicate.predictions.list) { - * console.log(page); - * } - * @param {Function} endpoint - Function that returns a promise for the next page of results - * @yields {object[]} Each page of results - */ - async *paginate(endpoint) { - const response = await endpoint(); - yield response.results; - if (response.next) { - const nextPage = () => this.request(response.next, { method: 'GET' }); - yield* this.paginate(nextPage); - } - } - - /** - * Wait for a prediction to finish. - * - * If the prediction has already finished, - * this function returns immediately. - * Otherwise, it polls the API until the prediction finishes. - * - * @async - * @param {object} prediction - Prediction object - * @param {object} options - Options - * @param {number} [options.interval] - Polling interval in milliseconds. Defaults to 250 - * @param {number} [options.maxAttempts] - Maximum number of polling attempts. Defaults to no limit - * @throws {Error} If the prediction doesn't complete within the maximum number of attempts - * @throws {Error} If the prediction failed - * @returns {Promise} Resolves with the completed prediction object - */ - async wait(prediction, options) { - const { id } = prediction; - if (!id) { - throw new Error('Invalid prediction'); - } - - if ( - prediction.status === 'succeeded' || - prediction.status === 'failed' || - prediction.status === 'canceled' - ) { - return prediction; - } - - let updatedPrediction = await this.predictions.get(id); - - // eslint-disable-next-line no-promise-executor-return - const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms)); - - let attempts = 0; - const interval = options.interval || 250; - const maxAttempts = options.maxAttempts || null; - - while ( - updatedPrediction.status !== 'succeeded' && - updatedPrediction.status !== 'failed' && - updatedPrediction.status !== 'canceled' - ) { - attempts += 1; - if (maxAttempts && attempts > maxAttempts) { - throw new Error( - `Prediction ${id} did not finish after ${maxAttempts} attempts` - ); - } - - /* eslint-disable no-await-in-loop */ - await sleep(interval); - updatedPrediction = await this.predictions.get(prediction.id); - /* eslint-enable no-await-in-loop */ - } - - if (updatedPrediction.status === 'failed') { - throw new Error(`Prediction failed: ${updatedPrediction.error}`); - } - - return updatedPrediction; - } -} - -module.exports = Replicate; diff --git a/index.test.ts b/index.test.ts index 7da943d5..37344b12 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,6 +1,5 @@ - import { expect, jest, test } from '@jest/globals'; -import Replicate, { Prediction } from 'replicate'; +import Replicate, { Prediction } from '.'; import nock from 'nock'; import fetch from 'cross-fetch'; @@ -10,8 +9,7 @@ describe('Replicate client', () => { const BASE_URL = 'https://api.replicate.com/v1'; beforeEach(() => { - client = new Replicate({ auth: 'test-token' }); - client.fetch = fetch; + client = new Replicate({ auth: 'test-token', fetch }); }); describe('constructor', () => { @@ -38,15 +36,13 @@ describe('Replicate client', () => { describe('collections.get', () => { test('Calls the correct API route', async () => { - nock(BASE_URL) - .get('/collections/super-resolution') - .reply(200, { - name: 'Super resolution', - slug: 'super-resolution', - description: - 'Upscaling models that create high-quality images from low-quality images.', - models: [], - }); + nock(BASE_URL).get('/collections/super-resolution').reply(200, { + name: 'Super resolution', + slug: 'super-resolution', + description: + 'Upscaling models that create high-quality images from low-quality images.', + models: [], + }); const collection = await client.collections.get('super-resolution'); expect(collection.name).toBe('Super resolution'); @@ -56,29 +52,26 @@ describe('Replicate client', () => { describe('models.get', () => { test('Calls the correct API route', async () => { - nock(BASE_URL) - .get('/models/replicate/hello-world') - .reply(200, { - url: 'https://replicate.com/replicate/hello-world', - owner: 'replicate', - name: 'hello-world', - description: 'A tiny model that says hello', - visibility: 'public', - github_url: 'https://github.com/replicate/cog-examples', - paper_url: null, - license_url: null, - run_count: 12345, - cover_image_url: '', - default_example: {}, - latest_version: {}, - }); + nock(BASE_URL).get('/models/replicate/hello-world').reply(200, { + url: 'https://replicate.com/replicate/hello-world', + owner: 'replicate', + name: 'hello-world', + description: 'A tiny model that says hello', + visibility: 'public', + github_url: 'https://github.com/replicate/cog-examples', + paper_url: null, + license_url: null, + run_count: 12345, + cover_image_url: '', + default_example: {}, + latest_version: {}, + }); await client.models.get('replicate', 'hello-world'); }); // Add more tests for error handling, edge cases, etc. }); - describe('predictions.create', () => { test('Calls the correct API route with the correct payload', async () => { nock(BASE_URL) @@ -111,7 +104,7 @@ describe('Replicate client', () => { text: 'Alice', }, webhook: 'http://test.host/webhook', - webhook_events_filter: [ 'output', 'completed' ], + webhook_events_filter: ['output', 'completed'], }); expect(prediction.id).toBe('ufawqhfynnddngldkgtslldrkq'); }); @@ -184,24 +177,28 @@ describe('Replicate client', () => { const predictions = await client.predictions.list(); expect(predictions.results.length).toBe(1); - expect(predictions.results[ 0 ].id).toBe('jpzd7hm5gfcapbfyt4mqytarku'); + expect(predictions.results[0].id).toBe('jpzd7hm5gfcapbfyt4mqytarku'); }); test('Paginates results', async () => { nock(BASE_URL) .get('/predictions') .reply(200, { - results: [ { id: 'ufawqhfynnddngldkgtslldrkq' } ], + results: [{ id: 'ufawqhfynnddngldkgtslldrkq' }], next: 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', }) - .get('/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw') + .get( + '/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw' + ) .reply(200, { - results: [ { id: 'rrr4z55ocneqzikepnug6xezpe' } ], + results: [{ id: 'rrr4z55ocneqzikepnug6xezpe' }], next: null, }); const results: Prediction[] = []; - for await (const batch of client.paginate(client.predictions.list)) { + for await (const batch of client.paginate( + client.predictions.list.bind(client.predictions) + )) { results.push(...batch); } expect(results).toEqual([ @@ -215,7 +212,9 @@ describe('Replicate client', () => { describe('trainings.create', () => { test('Calls the correct API route with the correct payload', async () => { nock(BASE_URL) - .post('/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings') + .post( + '/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings' + ) .reply(200, { id: 'zz4ibbonubfz7carwiefibzgga', version: '{version}', @@ -231,7 +230,6 @@ describe('Replicate client', () => { completed_at: null, }); - const training = await client.trainings.create( 'owner', 'model', @@ -272,7 +270,9 @@ describe('Replicate client', () => { completed_at: null, }); - const training = await client.trainings.get('zz4ibbonubfz7carwiefibzgga'); + const training = await client.trainings.get( + 'zz4ibbonubfz7carwiefibzgga' + ); expect(training.status).toBe('succeeded'); }); @@ -302,8 +302,9 @@ describe('Replicate client', () => { completed_at: null, }); - - const training = await client.trainings.cancel('zz4ibbonubfz7carwiefibzgga'); + const training = await client.trainings.cancel( + 'zz4ibbonubfz7carwiefibzgga' + ); expect(training.status).toBe('canceled'); }); diff --git a/index.ts b/index.ts new file mode 100644 index 00000000..4918219a --- /dev/null +++ b/index.ts @@ -0,0 +1,149 @@ +import APIClient, { Page } from './lib/apiClient'; +import Collections from './lib/collections'; +import Predictions, { WaitOptions } from './lib/predictions'; +import Trainings from './lib/trainings'; +import Models from './lib/models'; + +export { Prediction, PredictionOptions } from './lib/predictions'; +export { Page } from './lib/apiClient'; +export { Model, ModelVersion } from './lib/models'; +export { Collection } from './lib/collections'; +export { Training, TrainingOptions } from './lib/trainings'; + +export interface ReplicateOptions { + auth: string; + userAgent?: string; + baseUrl?: string; + fetch?: typeof fetch; +} + +export interface RunOptions { + input: object; + wait?: boolean | WaitOptions; + webhook?: string; + webhook_events_filter?: string[]; +} + +/** + * Replicate API client library + * + * @see https://replicate.com/docs/reference/http + * @example + * // Create a new Replicate API client instance + * const Replicate = require("replicate"); + * const replicate = new Replicate({ + * // get your token from https://replicate.com/account + * auth: process.env.REPLICATE_API_TOKEN, + * userAgent: "my-app/1.2.3" + * }); + * + * // Run a model and await the result: + * const model = 'owner/model:version-id' + * const input = {text: 'Hello, world!'} + * const output = await replicate.run(model, { input }); + */ +class Replicate { + client: APIClient; + public collections: Collections; + public models: Models; + public predictions: Predictions; + public trainings: Trainings; + + /** + * Create a new Replicate API client instance. + * + * @param {object} options - Configuration options for the client + * @param {string} options.auth - Required. API access token + * @param {string} options.userAgent - Identifier of your app + * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 + * @param {Function} [options.fetch] - Defaults to native fetch + */ + constructor(options: ReplicateOptions) { + this.client = new APIClient( + options.auth, + options.userAgent, + options.baseUrl, + options.fetch + ); + + this.predictions = new Predictions(this.client); + this.collections = new Collections(this.client); + this.models = new Models(this.client); + this.trainings = new Trainings(this.client); + } + + /** + * Run a model and wait for its output. + * + * @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}" + * @param {object} options + * @param {object} options.input - Required. An object with the model inputs + * @param {boolean|object} [options.wait] - Whether to wait for the prediction to finish. Defaults to false + * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 250 + * @param {number} [options.wait.maxAttempts] - Maximum number of polling attempts. Defaults to no limit + * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output + * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @throws {Error} If the prediction failed + * @returns {Promise} - Resolves with the output of running the model + */ + async run(identifier: string, options: RunOptions): Promise { + const pattern = + /^(?[a-zA-Z0-9-_]+?)\/(?[a-zA-Z0-9-_]+?):(?[0-9a-fA-F]+)$/; + const match = identifier.match(pattern); + + if (!match) { + throw new Error( + 'Invalid version. It must be in the format "owner/name:version"' + ); + } + + const { version } = match.groups as { + version: string; + owner: string; + name: string; + }; + + const prediction = await this.predictions.create({ + wait: true, + ...options, + version, + }); + + if (prediction.status === 'failed') { + throw new Error(`Prediction failed: ${prediction.error}`); + } + + return prediction.output; + } + + /** + * Paginate through a list of results. + * + * @generator + * @example + * for await (const page of replicate.paginate(replicate.predictions.list) { + * console.log(page); + * } + * @param {Function} endpoint - Function that returns a promise for the next page of results + * @yields {object[]} Each page of results + */ + async *paginate(endpoint: () => Promise>): AsyncGenerator { + yield* this.client.paginate(endpoint); + } + + /** + * The base URL being used by the API client. + */ + get baseUrl() { + return this.client.baseUrl; + } + + /** + * The user agent being used by the API client + */ + get userAgent() { + return this.client.userAgent; + } +} + +export default Replicate; diff --git a/lib/apiClient.ts b/lib/apiClient.ts new file mode 100644 index 00000000..a131b6e1 --- /dev/null +++ b/lib/apiClient.ts @@ -0,0 +1,90 @@ +import { FetchLike } from './types'; +import packageJSON from '../package.json'; + +export interface RequestOptions { + method?: string; + params?: object; + data?: object; +} + +export interface Page { + previous?: string; + next?: string; + results: T[]; +} + +const BASE_URL = 'https://api.replicate.com/v1'; + +export default class APIClient { + constructor( + readonly auth: string, + readonly userAgent: string = `replicate-javascript/${packageJSON.version}`, + readonly baseUrl: string = BASE_URL, + readonly customFetch?: FetchLike + ) {} + + async request(route: string, parameters: RequestOptions): Promise { + const url = new URL( + route.startsWith('/') ? route.slice(1) : route, + this.baseUrl.endsWith('/') ? this.baseUrl : `${this.baseUrl}/` + ); + + const { method = 'GET', params = {}, data } = parameters; + + Object.entries(params).forEach(([key, value]) => { + url.searchParams.append(key, value); + }); + + const headers = { + Authorization: `Token ${this.auth}`, + 'Content-Type': 'application/json', + 'User-Agent': this.userAgent, + }; + + let localFetch = this.customFetch; + if (!localFetch) { + localFetch = fetch; + } + + const response = await localFetch(url, { + method, + headers, + body: data ? JSON.stringify(data) : undefined, + }); + + if (!response.ok) { + throw new Error(`API request failed: ${response.statusText}`); + } + + return response.json(); + } + + /** + * Paginate through a list of results. + * + * @generator + * @example + * for await (const page of replicate.paginate(replicate.predictions.list) { + * console.log(page); + * } + * @param {Function} endpoint - Function that returns a promise for the next page of results + * @yields {object[]} Each page of results + */ + async *paginate(endpoint: () => Promise>): AsyncGenerator { + const response = await endpoint(); + + let next = response.next; + + yield response.results; + + if (response.next) { + console.log(response.next); + + const nextPage = (): Promise> => { + return this.request(response.next as string, { method: 'GET' }); + }; + + yield* this.paginate(nextPage); + } + } +} diff --git a/lib/collections.js b/lib/collections.js deleted file mode 100644 index 668262e3..00000000 --- a/lib/collections.js +++ /dev/null @@ -1,13 +0,0 @@ -/** - * Fetch a model collection - * - * @param {string} collection_slug - Required. The slug of the collection. See http://replicate.com/collections - * @returns {Promise} - Resolves with the collection data - */ -async function getCollection(collection_slug) { - return this.request(`/collections/${collection_slug}`, { - method: 'GET', - }); -} - -module.exports = { get: getCollection }; diff --git a/lib/collections.ts b/lib/collections.ts new file mode 100644 index 00000000..45241cfe --- /dev/null +++ b/lib/collections.ts @@ -0,0 +1,24 @@ +import APIClient from './apiClient'; +import { Model } from './models'; + +export interface Collection { + name: string; + slug: string; + description: string; + models: Model[]; +} + +export default class Collections { + constructor(private client: APIClient) {} + /** + * Fetch a model collection + * + * @param {string} collection_slug - Required. The slug of the collection. See http://replicate.com/collections + * @returns {Promise} - Resolves with the collection data + */ + async get(collection_slug: string): Promise { + return this.client.request(`/collections/${collection_slug}`, { + method: 'GET', + }); + } +} diff --git a/lib/models.js b/lib/models.js deleted file mode 100644 index 55c9b43b..00000000 --- a/lib/models.js +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Get information about a model - * - * @param {string} model_owner - Required. The name of the user or organization that owns the model - * @param {string} model_name - Required. The name of the model - * @returns {Promise} Resolves with the model data - */ -async function getModel(model_owner, model_name) { - return this.request(`/models/${model_owner}/${model_name}`, { - method: 'GET', - }); -} - -/** - * List model versions - * - * @param {string} model_owner - Required. The name of the user or organization that owns the model - * @param {string} model_name - Required. The name of the model - * @returns {Promise} Resolves with the list of model versions - */ -async function listModelVersions(model_owner, model_name) { - return this.request(`/models/${model_owner}/${model_name}/versions`, { - method: 'GET', - }); -} - -/** - * Get a specific model version - * - * @param {string} model_owner - Required. The name of the user or organization that owns the model - * @param {string} model_name - Required. The name of the model - * @param {string} version_id - Required. The model version - * @returns {Promise} Resolves with the model version data - */ -async function getModelVersion(model_owner, model_name, version_id) { - return this.request( - `/models/${model_owner}/${model_name}/versions/${version_id}`, - { - method: 'GET', - } - ); -} - -module.exports = { - get: getModel, - versions: { list: listModelVersions, get: getModelVersion }, -}; diff --git a/lib/models.ts b/lib/models.ts new file mode 100644 index 00000000..2090557b --- /dev/null +++ b/lib/models.ts @@ -0,0 +1,85 @@ +import APIClient from './apiClient'; +import { Prediction } from './predictions'; + +export interface Model { + url: string; + owner: string; + name: string; + description: string; + visibility: 'public' | 'private'; + github_url: string; + paper_url: string; + license_url: string; + run_count: number; + cover_image_url: string; + default_example?: Prediction; + latest_version?: ModelVersion; +} + +export interface ModelVersion { + id: string; + created_at: string; + cog_version: string; + openapi_schema: object; +} + +export default class Models { + public versions: ModelVersions; + constructor(private client: APIClient) { + this.versions = new ModelVersions(client); + } + + /** + * Get information about a model + * + * @param {string} model_owner - Required. The name of the user or organization that owns the model + * @param {string} model_name - Required. The name of the model + * @returns {Promise} Resolves with the model data + */ + async get(model_owner: string, model_name: string): Promise { + return this.client.request(`/models/${model_owner}/${model_name}`, { + method: 'GET', + }); + } +} + +export class ModelVersions { + constructor(private client: APIClient) {} + + /** + * List model versions + * + * @param {string} model_owner - Required. The name of the user or organization that owns the model + * @param {string} model_name - Required. The name of the model + * @returns {Promise} Resolves with the list of model versions + */ + async list(model_owner: string, model_name: string): Promise { + return this.client.request( + `/models/${model_owner}/${model_name}/versions`, + { + method: 'GET', + } + ); + } + + /** + * Get a specific model version + * + * @param {string} model_owner - Required. The name of the user or organization that owns the model + * @param {string} model_name - Required. The name of the model + * @param {string} version_id - Required. The model version + * @returns {Promise} Resolves with the model version data + */ + async get( + model_owner: string, + model_name: string, + version_id: string + ): Promise { + return this.client.request( + `/models/${model_owner}/${model_name}/versions/${version_id}`, + { + method: 'GET', + } + ); + } +} diff --git a/lib/predictions.js b/lib/predictions.js deleted file mode 100644 index b688feeb..00000000 --- a/lib/predictions.js +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Create a new prediction - * - * @param {object} options - * @param {string} options.version - Required. The model version - * @param {object} options.input - Required. An object with the model inputs - * @param {boolean|object} [options.wait] - Whether to wait for the prediction to finish. Defaults to false - * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 250 - * @param {number} [options.wait.maxAttempts] - Maximum number of polling attempts. Defaults to no limit - * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @returns {Promise} Resolves with the created prediction data - */ -async function createPrediction(options) { - const { wait, ...data } = options; - - const prediction = this.request('/predictions', { - method: 'POST', - data, - }); - - if (wait) { - const { maxAttempts, interval } = options.wait; - return this.wait(await prediction, { maxAttempts, interval }); - } - - return prediction; -} - -/** - * Fetch a prediction by ID - * - * @param {number} prediction_id - Required. The prediction ID - * @returns {Promise} Resolves with the prediction data - */ -async function getPrediction(prediction_id) { - return this.request(`/predictions/${prediction_id}`, { - method: 'GET', - }); -} - -/** - * List all predictions - * - * @returns {Promise} - Resolves with a page of predictions - */ -async function listPredictions() { - return this.request('/predictions', { - method: 'GET', - }); -} - -module.exports = { - create: createPrediction, - get: getPrediction, - list: listPredictions, -}; diff --git a/lib/predictions.ts b/lib/predictions.ts new file mode 100644 index 00000000..493563c6 --- /dev/null +++ b/lib/predictions.ts @@ -0,0 +1,156 @@ +import APIClient, { Page } from './apiClient'; +import { Status, WebhookEventType } from './types'; + +export interface WaitOptions { + interval?: number; + maxAttempts?: number; +} + +export interface PredictionOptions { + version: string; + input: object; + wait?: WaitOptions | boolean; + webhook?: string; + webhook_events_filter?: string[]; +} + +export interface Prediction { + id: string; + status: Status; + version: string; + input: object; + output: any; + source: 'api' | 'web'; + error?: any; + logs?: string; + metrics?: { + predicti_time?: number; + }; + webhook?: string; + webhook_events_filter?: WebhookEventType[]; + created_at: string; + updated_at: string; + completed_at?: string; +} + +export default class Predictions { + constructor(private client: APIClient) {} + + /** + * Create a new prediction + * + * @param {object} options + * @param {string} options.version - Required. The model version + * @param {object} options.input - Required. An object with the model inputs + * @param {boolean|object} [options.wait] - Whether to wait for the prediction to finish. Defaults to false + * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 250 + * @param {number} [options.wait.maxAttempts] - Maximum number of polling attempts. Defaults to no limit + * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output + * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @returns {Promise} Resolves with the created prediction data + */ + async create(options: PredictionOptions): Promise { + const { wait, ...data } = options; + + const prediction: Prediction = await this.client.request('/predictions', { + method: 'POST', + data, + }); + + if (wait) { + let waitOptions = wait === true ? {} : wait; + return this.wait(prediction, waitOptions); + } + + return prediction; + } + + /** + * Fetch a prediction by ID + * @param predictionId + * @returns + */ + async get(predictionId: string): Promise { + return this.client.request(`/predictions/${predictionId}`, { + method: 'GET', + }); + } + + /** + * List all predictions + * + * @returns {Promise} - Resolves with a page of predictions + */ + async list(): Promise> { + return this.client.request('/predictions', { + method: 'GET', + }); + } + + /** + * Wait for a prediction to finish. + * + * If the prediction has already finished, + * this function returns immediately. + * Otherwise, it polls the API until the prediction finishes. + * + * @async + * @param {object} prediction - Prediction object + * @param {object} options - Options + * @param {number} [options.interval] - Polling interval in milliseconds. Defaults to 250 + * @param {number} [options.maxAttempts] - Maximum number of polling attempts. Defaults to no limit + * @throws {Error} If the prediction doesn't complete within the maximum number of attempts + * @throws {Error} If the prediction failed + * @returns {Promise} Resolves with the completed prediction object + */ + async wait( + prediction: Prediction, + options: WaitOptions + ): Promise { + const { id } = prediction; + if (!id) { + throw new Error('Invalid prediction'); + } + + if ( + prediction.status === 'succeeded' || + prediction.status === 'failed' || + prediction.status === 'canceled' + ) { + return prediction; + } + + let updatedPrediction: Prediction = await this.get(id); + + const sleep = (ms: number) => + new Promise((resolve) => setTimeout(resolve, ms)); + + let attempts = 0; + const interval = options.interval || 250; + const maxAttempts = options.maxAttempts || null; + + while ( + updatedPrediction.status !== 'succeeded' && + updatedPrediction.status !== 'failed' && + updatedPrediction.status !== 'canceled' + ) { + attempts += 1; + if (maxAttempts && attempts > maxAttempts) { + throw new Error( + `Prediction ${id} did not finish after ${maxAttempts} attempts` + ); + } + + /* eslint-disable no-await-in-loop */ + await sleep(interval); + updatedPrediction = await this.get(prediction.id); + /* eslint-enable no-await-in-loop */ + } + + if (updatedPrediction.status === 'failed') { + throw new Error(`Prediction failed: ${updatedPrediction.error}`); + } + + return updatedPrediction; + } +} diff --git a/lib/trainings.js b/lib/trainings.js deleted file mode 100644 index c2512947..00000000 --- a/lib/trainings.js +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Create a new training - * - * @param {string} model_owner - Required. The username of the user or organization who owns the model - * @param {string} model_name - Required. The name of the model - * @param {string} version_id - Required. The version ID - * @param {object} options - * @param {string} options.destination - Required. The destination for the trained version in the form "{username}/{model_name}" - * @param {object} options.input - Required. An object with the model inputs - * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @returns {Promise} Resolves with the data for the created training - */ -async function createTraining(model_owner, model_name, version_id, options) { - const { ...data } = options; - - const training = this.request(`/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, { - method: 'POST', - data, - }); - - return training; -} - -/** - * Fetch a training by ID - * - * @param {string} training_id - Required. The training ID - * @returns {Promise} Resolves with the data for the training - */ -async function getTraining(training_id) { - return this.request(`/trainings/${training_id}`, { - method: 'GET', - }); -} - -/** - * Cancel a training by ID - * - * @param {string} training_id - Required. The training ID - * @returns {Promise} Resolves with the data for the training - */ -async function cancelTraining(training_id) { - return this.request(`/trainings/${training_id}/cancel`, { - method: 'POST', - }); -} - -module.exports = { - create: createTraining, - get: getTraining, - cancel: cancelTraining, -}; diff --git a/lib/trainings.ts b/lib/trainings.ts new file mode 100644 index 00000000..27c31b65 --- /dev/null +++ b/lib/trainings.ts @@ -0,0 +1,80 @@ +import APIClient from './apiClient'; +import { Status, WebhookEventType } from './types'; + +export interface TrainingOptions { + destination: `${string}/${string}`; + input: object; + webhook?: string; + webhook_events_filter?: WebhookEventType[]; +} + +export interface Training { + id: string; + version: string; + status: Status; + input: object; + output: object | null; + error: any; + logs: any; + started_at: string | null; + created_at: string | null; + completed_at: string | null; +} + +export default class Trainings { + constructor(private client: APIClient) {} + + /** + * Create a new training + * + * @param {string} model_owner - Required. The username of the user or organization who owns the model + * @param {string} model_name - Required. The name of the model + * @param {string} version_id - Required. The version ID + * @param {object} options + * @param {string} options.destination - Required. The destination for the trained version in the form "{username}/{model_name}" + * @param {object} options.input - Required. An object with the model inputs + * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates + * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @returns {Promise} Resolves with the data for the created training + */ + async create( + model_owner: string, + model_name: string, + version_id: string, + options: TrainingOptions + ): Promise { + const { ...data } = options; + + return await this.client.request( + `/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, + { + method: 'POST', + data, + } + ); + } + + /** + * Fetch a training by ID + * + * @param {string} training_id - Required. The training ID + * @returns {Promise} Resolves with the data for the training + */ + async get(training_id: string): Promise { + return this.client.request(`/trainings/${training_id}`, { + method: 'GET', + }); + } + + /** + * Cancel a training by ID + * + * @param {string} training_id - Required. The training ID + * @returns {Promise} Resolves with the data for the training + */ + async cancel(training_id: string): Promise { + return this.client.request(`/trainings/${training_id}/cancel`, { + method: 'POST', + }); + } +} diff --git a/lib/types.ts b/lib/types.ts new file mode 100644 index 00000000..1b67ad2a --- /dev/null +++ b/lib/types.ts @@ -0,0 +1,15 @@ +export type WebhookEventType = 'start' | 'output' | 'logs' | 'completed'; +export type Status = + | 'starting' + | 'processing' + | 'succeeded' + | 'failed' + | 'canceled'; + +export { Request, Response, Headers } from 'cross-fetch'; //Note: not used in build. + +type RequestInfo = Request | string; +export type FetchLike = ( + input: RequestInfo | URL, + init?: RequestInit +) => Promise; diff --git a/package.json b/package.json index 071b9f48..c1ea9bee 100644 --- a/package.json +++ b/package.json @@ -1,16 +1,20 @@ { "name": "replicate", - "version": "0.9.0", + "version": "0.9.1", "description": "JavaScript client for Replicate", "repository": "github:replicate/replicate-javascript", "homepage": "https://github.com/replicate/replicate-javascript#readme", "bugs": "https://github.com/replicate/replicate-javascript/issues", "license": "Apache-2.0", - "main": "index.js", + "main": "dist/index.js", + "types": "dist/index.d.ts", "scripts": { "lint": "eslint .", - "test": "jest" + "test": "jest", + "build": "tsc", + "clean": "rm -rf ./dist" }, + "prepublish": "tsc", "devDependencies": { "@tsconfig/recommended": "^1.0.2", "@types/jest": "^29.5.0", diff --git a/tsconfig.json b/tsconfig.json index 97cb33b3..1297ec05 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,3 +1,14 @@ { - "extends": "@tsconfig/recommended/tsconfig.json" + "extends": "@tsconfig/recommended/tsconfig.json", + "compilerOptions": { + "target": "es5", + "module": "commonjs", + "lib": ["es2017", "es7", "es6", "dom"], + "declaration": true, + "outDir": "dist", + "strict": true, + "esModuleInterop": true, + "resolveJsonModule": true + }, + "exclude": ["node_modules", "dist"] }