From 8f9408feaea382d5eb01c1420a1c7eb910ae20a2 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 1 Nov 2023 14:22:15 -0700 Subject: [PATCH 1/4] Add support for hardware.list endpoint --- index.d.ts | 39 ++++++++++++++++++++++++--------------- index.js | 17 +++++++++++------ index.test.ts | 25 ++++++++++++++++++++++--- lib/hardware.js | 16 ++++++++++++++++ 4 files changed, 73 insertions(+), 24 deletions(-) create mode 100644 lib/hardware.js diff --git a/index.d.ts b/index.d.ts index 601e15b8..f606e9bc 100644 --- a/index.d.ts +++ b/index.d.ts @@ -14,6 +14,11 @@ declare module 'replicate' { models?: Model[]; } + export interface Hardware { + sku: string; + name: string + } + export interface Model { url: string; owner: string; @@ -115,6 +120,25 @@ declare module 'replicate' { get(collection_slug: string): Promise; }; + deployments: { + predictions: { + create( + deployment_owner: string, + deployment_name: string, + options: { + input: object; + stream?: boolean; + webhook?: string; + webhook_events_filter?: WebhookEventType[]; + } + ): Promise; + }; + }; + + hardware: { + list(): Promise + } + models: { get(model_owner: string, model_name: string): Promise; list(): Promise>; @@ -157,20 +181,5 @@ declare module 'replicate' { cancel(training_id: string): Promise; list(): Promise>; }; - - deployments: { - predictions: { - create( - deployment_owner: string, - deployment_name: string, - options: { - input: object; - stream?: boolean; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } - ): Promise; - }; - }; } } diff --git a/index.js b/index.js index 902ba572..ca7fb890 100644 --- a/index.js +++ b/index.js @@ -3,6 +3,7 @@ const { withAutomaticRetries } = require('./lib/util'); const collections = require('./lib/collections'); const deployments = require('./lib/deployments'); +const hardware = require('./lib/hardware'); const models = require('./lib/models'); const predictions = require('./lib/predictions'); const trainings = require('./lib/trainings'); @@ -49,6 +50,16 @@ class Replicate { get: collections.get.bind(this), }; + this.deployments = { + predictions: { + create: deployments.predictions.create.bind(this), + } + }; + + this.hardware = { + list: hardware.list.bind(this), + }; + this.models = { get: models.get.bind(this), list: models.list.bind(this), @@ -71,12 +82,6 @@ class Replicate { cancel: trainings.cancel.bind(this), list: trainings.list.bind(this), }; - - this.deployments = { - predictions: { - create: deployments.predictions.create.bind(this), - } - }; } /** diff --git a/index.test.ts b/index.test.ts index ab4e9d66..63815d7b 100644 --- a/index.test.ts +++ b/index.test.ts @@ -136,12 +136,12 @@ describe('Replicate client', () => { nock(BASE_URL) .get('/models') .reply(200, { - results: [{ url: 'https://replicate.com/some-user/model-1' }], + results: [ { url: 'https://replicate.com/some-user/model-1' } ], next: 'https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', }) .get('/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw') .reply(200, { - results: [{ url: 'https://replicate.com/some-user/model-2' }], + results: [ { url: 'https://replicate.com/some-user/model-2' } ], next: null, }); @@ -149,7 +149,7 @@ describe('Replicate client', () => { for await (const batch of client.paginate(client.models.list)) { results.push(...batch); } - expect(results).toEqual([{ url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' }]); + expect(results).toEqual([ { url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' } ]); // Add more tests for error handling, edge cases, etc. }); @@ -662,6 +662,25 @@ describe('Replicate client', () => { // Add more tests for error handling, edge cases, etc. }); + describe('hardware.list', () => { + test('Calls the correct API route', async () => { + nock(BASE_URL) + .get('/hardware') + .reply(200, [ + { name: "CPU", sku: "cpu" }, + { name: "Nvidia T4 GPU", sku: "gpu-t4" }, + { name: "Nvidia A40 GPU", sku: "gpu-a40-small" }, + { name: "Nvidia A40 (Large) GPU", sku: "gpu-a40-large" }, + ]); + + const hardware = await client.hardware.list(); + expect(hardware.length).toBe(4); + expect(hardware[ 0 ].name).toBe('CPU'); + expect(hardware[ 0 ].sku).toBe('cpu'); + }); + // Add more tests for error handling, edge cases, etc. + }); + describe('run', () => { test('Calls the correct API routes', async () => { let firstPollingRequest = true; diff --git a/lib/hardware.js b/lib/hardware.js new file mode 100644 index 00000000..487f3b88 --- /dev/null +++ b/lib/hardware.js @@ -0,0 +1,16 @@ +/** + * List hardware + * + * @returns {Promise} Resolves with the array of hardware + */ +async function listHardware() { + const response = await this.request('/hardware', { + method: 'GET', + }); + + return response.json(); +} + +module.exports = { + list: listHardware, +}; From 00a5963ae6aa02cdc0a4bb3057590b1fc7f7aed8 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 2 Nov 2023 04:00:01 -0700 Subject: [PATCH 2/4] Add support for models.create endpoint --- index.d.ts | 13 +++++++++++++ index.js | 1 + index.test.ts | 29 +++++++++++++++++++++++++++++ lib/models.js | 27 +++++++++++++++++++++++++++ 4 files changed, 70 insertions(+) diff --git a/index.d.ts b/index.d.ts index f606e9bc..a3e2ee04 100644 --- a/index.d.ts +++ b/index.d.ts @@ -1,5 +1,6 @@ declare module 'replicate' { type Status = 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled'; + type Visibility = 'public' | 'private'; type WebhookEventType = 'start' | 'output' | 'logs' | 'completed'; export interface ApiError extends Error { @@ -142,6 +143,18 @@ declare module 'replicate' { models: { get(model_owner: string, model_name: string): Promise; list(): Promise>; + create( + model_owner: string, + model_name: string, + options: { + visibility: Visibility; + hardware: string; + description?: string; + github_url?: string; + paper_url?: string; + license_url?: string; + cover_image_url?: string; + }): Promise; versions: { list(model_owner: string, model_name: string): Promise; get( diff --git a/index.js b/index.js index ca7fb890..acb07eb7 100644 --- a/index.js +++ b/index.js @@ -63,6 +63,7 @@ class Replicate { this.models = { get: models.get.bind(this), list: models.list.bind(this), + create: models.create.bind(this), versions: { list: models.versions.list.bind(this), get: models.versions.get.bind(this), diff --git a/index.test.ts b/index.test.ts index 63815d7b..afba8ca2 100644 --- a/index.test.ts +++ b/index.test.ts @@ -681,6 +681,35 @@ describe('Replicate client', () => { // Add more tests for error handling, edge cases, etc. }); + describe('models.create', () => { + test('Calls the correct API route with the correct payload', async () => { + nock(BASE_URL) + .post('/models') + .reply(200, { + owner: 'test-owner', + name: 'test-model', + visibility: 'public', + hardware: 'cpu', + description: 'A test model', + }); + + const model = await client.models.create( + 'test-owner', + 'test-model', + { + visibility: 'public', + hardware: 'cpu', + description: 'A test model', + }); + + expect(model.owner).toBe('test-owner'); + expect(model.name).toBe('test-model'); + expect(model.visibility).toBe('public'); + // expect(model.hardware).toBe('test-hardware'); + expect(model.description).toBe('A test model'); + }); + }); + describe('run', () => { test('Calls the correct API routes', async () => { let firstPollingRequest = true; diff --git a/lib/models.js b/lib/models.js index be057503..3c4e5b1d 100644 --- a/lib/models.js +++ b/lib/models.js @@ -57,8 +57,35 @@ async function listModels() { return response.json(); } +/** + * Create a new model + * + * @param {string} model_owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. + * @param {string} model_name - Required. The name of the model. This must be unique among all models owned by the user or organization. + * @param {object} options + * @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. + * @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`. + * @param {string} options.description - A description of the model. + * @param {string} options.github_url - A URL for the model's source code on GitHub. + * @param {string} options.paper_url - A URL for the model's paper. + * @param {string} options.license_url - A URL for the model's license. + * @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file. + * @returns {Promise} Resolves with the model version data + */ +async function createModel(model_owner, model_name, options) { + const data = { owner: model_owner, name: model_name, ...options }; + + const response = await this.request('/models', { + method: 'POST', + data, + }); + + return response.json(); +} + module.exports = { get: getModel, list: listModels, + create: createModel, versions: { list: listModelVersions, get: getModelVersion }, }; From e473ac4d4dfa768cb820a4da9be8de756d68c1bd Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 5 Nov 2023 05:27:55 -0800 Subject: [PATCH 3/4] Update README --- README.md | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/README.md b/README.md index 36fbcb0f..1eb7a2b5 100644 --- a/README.md +++ b/README.md @@ -236,6 +236,43 @@ const response = await replicate.models.list(); } ``` +### `replicate.models.create` + +Create a new public or private model. + +```js +const response = await replicate.models.create(model_owner, model_name, options); +``` + +| name | type | description | +| ------------------------- | ------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `model_owner` | string | **Required**. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. | +| `model_name` | string | **Required**. The name of the model. This must be unique among all models owned by the user or organization. | +| `options.visibility` | string | **Required**. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. | +| `options.hardware` | string | **Required**. The SKU for the hardware used to run the model. Possible values can be found by calling [`replicate.hardware.list()](#replicatehardwarelist)`. | +| `options.description` | string | A description of the model. | +| `options.github_url` | string | A URL for the model's source code on GitHub. | +| `options.paper_url` | string | A URL for the model's paper. | +| `options.license_url` | string | A URL for the model's license. | +| `options.cover_image_url` | string | A URL for the model's cover image. This should be an image file. | + +### `replicate.hardware.list` + +List available hardware for running models on Replicate. + +```js +const response = await replicate.hardware.list() +``` + +```jsonc +[ + {"name": "CPU", "sku": "cpu" }, + {"name": "Nvidia T4 GPU", "sku": "gpu-t4" }, + {"name": "Nvidia A40 GPU", "sku": "gpu-a40-small" }, + {"name": "Nvidia A40 (Large) GPU", "sku": "gpu-a40-large" }, +] +``` + ### `replicate.models.versions.list` Get a list of all published versions of a model, including input and output schemas for each version. From 295786380b69a2a6e8fe73d5404f6bab9525948f Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 6 Nov 2023 03:51:57 -0800 Subject: [PATCH 4/4] Update test expectation --- index.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/index.test.ts b/index.test.ts index afba8ca2..377357b1 100644 --- a/index.test.ts +++ b/index.test.ts @@ -705,7 +705,7 @@ describe('Replicate client', () => { expect(model.owner).toBe('test-owner'); expect(model.name).toBe('test-model'); expect(model.visibility).toBe('public'); - // expect(model.hardware).toBe('test-hardware'); + // expect(model.hardware).toBe('cpu'); expect(model.description).toBe('A test model'); }); });