diff --git a/README.md b/README.md index 36fbcb0f..1eb7a2b5 100644 --- a/README.md +++ b/README.md @@ -236,6 +236,43 @@ const response = await replicate.models.list(); } ``` +### `replicate.models.create` + +Create a new public or private model. + +```js +const response = await replicate.models.create(model_owner, model_name, options); +``` + +| name | type | description | +| ------------------------- | ------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `model_owner` | string | **Required**. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. | +| `model_name` | string | **Required**. The name of the model. This must be unique among all models owned by the user or organization. | +| `options.visibility` | string | **Required**. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. | +| `options.hardware` | string | **Required**. The SKU for the hardware used to run the model. Possible values can be found by calling [`replicate.hardware.list()](#replicatehardwarelist)`. | +| `options.description` | string | A description of the model. | +| `options.github_url` | string | A URL for the model's source code on GitHub. | +| `options.paper_url` | string | A URL for the model's paper. | +| `options.license_url` | string | A URL for the model's license. | +| `options.cover_image_url` | string | A URL for the model's cover image. This should be an image file. | + +### `replicate.hardware.list` + +List available hardware for running models on Replicate. + +```js +const response = await replicate.hardware.list() +``` + +```jsonc +[ + {"name": "CPU", "sku": "cpu" }, + {"name": "Nvidia T4 GPU", "sku": "gpu-t4" }, + {"name": "Nvidia A40 GPU", "sku": "gpu-a40-small" }, + {"name": "Nvidia A40 (Large) GPU", "sku": "gpu-a40-large" }, +] +``` + ### `replicate.models.versions.list` Get a list of all published versions of a model, including input and output schemas for each version. diff --git a/index.d.ts b/index.d.ts index 601e15b8..a3e2ee04 100644 --- a/index.d.ts +++ b/index.d.ts @@ -1,5 +1,6 @@ declare module 'replicate' { type Status = 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled'; + type Visibility = 'public' | 'private'; type WebhookEventType = 'start' | 'output' | 'logs' | 'completed'; export interface ApiError extends Error { @@ -14,6 +15,11 @@ declare module 'replicate' { models?: Model[]; } + export interface Hardware { + sku: string; + name: string + } + export interface Model { url: string; owner: string; @@ -115,9 +121,40 @@ declare module 'replicate' { get(collection_slug: string): Promise; }; + deployments: { + predictions: { + create( + deployment_owner: string, + deployment_name: string, + options: { + input: object; + stream?: boolean; + webhook?: string; + webhook_events_filter?: WebhookEventType[]; + } + ): Promise; + }; + }; + + hardware: { + list(): Promise + } + models: { get(model_owner: string, model_name: string): Promise; list(): Promise>; + create( + model_owner: string, + model_name: string, + options: { + visibility: Visibility; + hardware: string; + description?: string; + github_url?: string; + paper_url?: string; + license_url?: string; + cover_image_url?: string; + }): Promise; versions: { list(model_owner: string, model_name: string): Promise; get( @@ -157,20 +194,5 @@ declare module 'replicate' { cancel(training_id: string): Promise; list(): Promise>; }; - - deployments: { - predictions: { - create( - deployment_owner: string, - deployment_name: string, - options: { - input: object; - stream?: boolean; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } - ): Promise; - }; - }; } } diff --git a/index.js b/index.js index 902ba572..acb07eb7 100644 --- a/index.js +++ b/index.js @@ -3,6 +3,7 @@ const { withAutomaticRetries } = require('./lib/util'); const collections = require('./lib/collections'); const deployments = require('./lib/deployments'); +const hardware = require('./lib/hardware'); const models = require('./lib/models'); const predictions = require('./lib/predictions'); const trainings = require('./lib/trainings'); @@ -49,9 +50,20 @@ class Replicate { get: collections.get.bind(this), }; + this.deployments = { + predictions: { + create: deployments.predictions.create.bind(this), + } + }; + + this.hardware = { + list: hardware.list.bind(this), + }; + this.models = { get: models.get.bind(this), list: models.list.bind(this), + create: models.create.bind(this), versions: { list: models.versions.list.bind(this), get: models.versions.get.bind(this), @@ -71,12 +83,6 @@ class Replicate { cancel: trainings.cancel.bind(this), list: trainings.list.bind(this), }; - - this.deployments = { - predictions: { - create: deployments.predictions.create.bind(this), - } - }; } /** diff --git a/index.test.ts b/index.test.ts index ab4e9d66..377357b1 100644 --- a/index.test.ts +++ b/index.test.ts @@ -136,12 +136,12 @@ describe('Replicate client', () => { nock(BASE_URL) .get('/models') .reply(200, { - results: [{ url: 'https://replicate.com/some-user/model-1' }], + results: [ { url: 'https://replicate.com/some-user/model-1' } ], next: 'https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', }) .get('/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw') .reply(200, { - results: [{ url: 'https://replicate.com/some-user/model-2' }], + results: [ { url: 'https://replicate.com/some-user/model-2' } ], next: null, }); @@ -149,7 +149,7 @@ describe('Replicate client', () => { for await (const batch of client.paginate(client.models.list)) { results.push(...batch); } - expect(results).toEqual([{ url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' }]); + expect(results).toEqual([ { url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' } ]); // Add more tests for error handling, edge cases, etc. }); @@ -662,6 +662,54 @@ describe('Replicate client', () => { // Add more tests for error handling, edge cases, etc. }); + describe('hardware.list', () => { + test('Calls the correct API route', async () => { + nock(BASE_URL) + .get('/hardware') + .reply(200, [ + { name: "CPU", sku: "cpu" }, + { name: "Nvidia T4 GPU", sku: "gpu-t4" }, + { name: "Nvidia A40 GPU", sku: "gpu-a40-small" }, + { name: "Nvidia A40 (Large) GPU", sku: "gpu-a40-large" }, + ]); + + const hardware = await client.hardware.list(); + expect(hardware.length).toBe(4); + expect(hardware[ 0 ].name).toBe('CPU'); + expect(hardware[ 0 ].sku).toBe('cpu'); + }); + // Add more tests for error handling, edge cases, etc. + }); + + describe('models.create', () => { + test('Calls the correct API route with the correct payload', async () => { + nock(BASE_URL) + .post('/models') + .reply(200, { + owner: 'test-owner', + name: 'test-model', + visibility: 'public', + hardware: 'cpu', + description: 'A test model', + }); + + const model = await client.models.create( + 'test-owner', + 'test-model', + { + visibility: 'public', + hardware: 'cpu', + description: 'A test model', + }); + + expect(model.owner).toBe('test-owner'); + expect(model.name).toBe('test-model'); + expect(model.visibility).toBe('public'); + // expect(model.hardware).toBe('cpu'); + expect(model.description).toBe('A test model'); + }); + }); + describe('run', () => { test('Calls the correct API routes', async () => { let firstPollingRequest = true; diff --git a/lib/hardware.js b/lib/hardware.js new file mode 100644 index 00000000..487f3b88 --- /dev/null +++ b/lib/hardware.js @@ -0,0 +1,16 @@ +/** + * List hardware + * + * @returns {Promise} Resolves with the array of hardware + */ +async function listHardware() { + const response = await this.request('/hardware', { + method: 'GET', + }); + + return response.json(); +} + +module.exports = { + list: listHardware, +}; diff --git a/lib/models.js b/lib/models.js index be057503..3c4e5b1d 100644 --- a/lib/models.js +++ b/lib/models.js @@ -57,8 +57,35 @@ async function listModels() { return response.json(); } +/** + * Create a new model + * + * @param {string} model_owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. + * @param {string} model_name - Required. The name of the model. This must be unique among all models owned by the user or organization. + * @param {object} options + * @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. + * @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`. + * @param {string} options.description - A description of the model. + * @param {string} options.github_url - A URL for the model's source code on GitHub. + * @param {string} options.paper_url - A URL for the model's paper. + * @param {string} options.license_url - A URL for the model's license. + * @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file. + * @returns {Promise} Resolves with the model version data + */ +async function createModel(model_owner, model_name, options) { + const data = { owner: model_owner, name: model_name, ...options }; + + const response = await this.request('/models', { + method: 'POST', + data, + }); + + return response.json(); +} + module.exports = { get: getModel, list: listModels, + create: createModel, versions: { list: listModelVersions, get: getModelVersion }, };