diff --git a/index.d.ts b/index.d.ts index 27dcdcb8..455ea28e 100644 --- a/index.d.ts +++ b/index.d.ts @@ -113,5 +113,21 @@ declare module 'replicate' { get(prediction_id: string): Promise; list(): Promise>; }; + + trainings: { + create( + model_owner: string, + model_name: string, + version_id: string, + options: { + destination: `${string}/${string}`; + input: object; + webhook?: string; + webhook_events_filter?: WebhookEventType[]; + } + ): Promise; + get(options: TrainingsGetOptions): Promise; + cancel(options: TrainingsGetOptions): Promise; + }; } } diff --git a/index.js b/index.js index cd9468df..cecc45a8 100644 --- a/index.js +++ b/index.js @@ -3,6 +3,7 @@ const axios = require('axios'); const collections = require('./lib/collections'); const models = require('./lib/models'); const predictions = require('./lib/predictions'); +const trainings = require('./lib/trainings'); const packageJSON = require('./package.json'); /** @@ -63,6 +64,12 @@ class Replicate { get: predictions.get.bind(this), list: predictions.list.bind(this), }; + + this.trainings = { + create: trainings.create.bind(this), + get: trainings.get.bind(this), + cancel: trainings.cancel.bind(this), + }; } /** diff --git a/index.test.ts b/index.test.ts index 70e644f8..97e1f029 100644 --- a/index.test.ts +++ b/index.test.ts @@ -9,7 +9,7 @@ describe('Replicate client', () => { beforeEach(() => { client = new Replicate({ auth: 'test-token' }); - client['instance'] = jest.fn(); + client[ 'instance' ] = jest.fn(); }); describe('constructor', () => { @@ -36,7 +36,7 @@ describe('Replicate client', () => { describe('collections.get', () => { test('Calls the correct API route', async () => { - client['instance'].mockResolvedValueOnce({ + client[ 'instance' ].mockResolvedValueOnce({ data: { name: 'Super resolution', slug: 'super-resolution', @@ -46,7 +46,7 @@ describe('Replicate client', () => { }, }); const collection = await client.collections.get('super-resolution'); - expect(client['instance']).toHaveBeenCalledWith( + expect(client[ 'instance' ]).toHaveBeenCalledWith( '/collections/super-resolution', { method: 'GET', @@ -60,7 +60,7 @@ describe('Replicate client', () => { describe('models.get', () => { test('Calls the correct API route', async () => { - client['instance'].mockResolvedValueOnce({ + client[ 'instance' ].mockResolvedValueOnce({ data: { url: 'https://replicate.com/replicate/hello-world', owner: 'replicate', @@ -77,7 +77,7 @@ describe('Replicate client', () => { }, }); await client.models.get('replicate', 'hello-world'); - expect(client['instance']).toHaveBeenCalledWith( + expect(client[ 'instance' ]).toHaveBeenCalledWith( '/models/replicate/hello-world', { method: 'GET', @@ -90,7 +90,7 @@ describe('Replicate client', () => { describe('predictions.create', () => { test('Calls the correct API route with the correct payload', async () => { - client['instance'].mockResolvedValueOnce({ + client[ 'instance' ].mockResolvedValueOnce({ data: { id: 'ufawqhfynnddngldkgtslldrkq', version: @@ -121,11 +121,11 @@ describe('Replicate client', () => { text: 'Alice', }, webhook: 'http://test.host/webhook', - webhook_events_filter: ['output', 'completed'], + webhook_events_filter: [ 'output', 'completed' ], }); expect(prediction.id).toBe('ufawqhfynnddngldkgtslldrkq'); - expect(client['instance']).toHaveBeenCalledWith('/predictions', { + expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', { method: 'POST', data: { version: @@ -134,7 +134,7 @@ describe('Replicate client', () => { text: 'Alice', }, webhook: 'http://test.host/webhook', - webhook_events_filter: ['output', 'completed'], + webhook_events_filter: [ 'output', 'completed' ], }, }); }); @@ -144,7 +144,7 @@ describe('Replicate client', () => { describe('predictions.get', () => { test('Calls the correct API route with the correct payload', async () => { - client['instance'].mockResolvedValueOnce({ + client[ 'instance' ].mockResolvedValueOnce({ data: { id: 'rrr4z55ocneqzikepnug6xezpe', version: @@ -178,7 +178,7 @@ describe('Replicate client', () => { ); expect(prediction.id).toBe('rrr4z55ocneqzikepnug6xezpe'); - expect(client['instance']).toHaveBeenCalledWith( + expect(client[ 'instance' ]).toHaveBeenCalledWith( '/predictions/rrr4z55ocneqzikepnug6xezpe', { method: 'GET', @@ -191,7 +191,7 @@ describe('Replicate client', () => { describe('predictions.list', () => { test('Calls the correct API route with the correct payload', async () => { - client['instance'].mockResolvedValueOnce({ + client[ 'instance' ].mockResolvedValueOnce({ data: { next: 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', previous: null, @@ -217,23 +217,23 @@ describe('Replicate client', () => { const predictions = await client.predictions.list(); expect(predictions.results.length).toBe(1); - expect(predictions.results[0].id).toBe('jpzd7hm5gfcapbfyt4mqytarku'); + expect(predictions.results[ 0 ].id).toBe('jpzd7hm5gfcapbfyt4mqytarku'); - expect(client['instance']).toHaveBeenCalledWith('/predictions', { + expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', { method: 'GET', }); }); test('Paginates results', async () => { - client['instance'].mockResolvedValueOnce({ + client[ 'instance' ].mockResolvedValueOnce({ data: { - results: [{ id: 'ufawqhfynnddngldkgtslldrkq' }], + results: [ { id: 'ufawqhfynnddngldkgtslldrkq' } ], next: 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', }, }); - client['instance'].mockResolvedValueOnce({ + client[ 'instance' ].mockResolvedValueOnce({ data: { - results: [{ id: 'rrr4z55ocneqzikepnug6xezpe' }], + results: [ { id: 'rrr4z55ocneqzikepnug6xezpe' } ], next: null, }, }); @@ -248,10 +248,10 @@ describe('Replicate client', () => { { id: 'rrr4z55ocneqzikepnug6xezpe' }, ]); - expect(client['instance']).toHaveBeenCalledWith('/predictions', { + expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', { method: 'GET', }); - expect(client['instance']).toHaveBeenCalledWith( + expect(client[ 'instance' ]).toHaveBeenCalledWith( 'https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw', { method: 'GET', @@ -262,15 +262,129 @@ describe('Replicate client', () => { // Add more tests for error handling, edge cases, etc. }); + describe('trainings.create', () => { + test('Calls the correct API route with the correct payload', async () => { + client[ 'instance' ].mockResolvedValueOnce({ + data: { + "id": "zz4ibbonubfz7carwiefibzgga", + "version": "{version}", + "status": "starting", + "input": { + "text": "..." + }, + "output": null, + "error": null, + "logs": null, + "started_at": null, + "created_at": "2023-03-28T21:47:58.566434Z", + "completed_at": null + } + }); + + const training = await client.trainings.create( + 'owner', + 'model', + '632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532', + { + destination: 'new_owner/new_model', + input: { + text: '...' + } + } + ); + expect(training.id).toBe('zz4ibbonubfz7carwiefibzgga'); + + expect(client[ 'instance' ]).toHaveBeenCalledWith('/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings', { + method: 'POST', + data: { + destination: 'new_owner/new_model', + input: { + text: '...' + }, + } + }); + }); + + // Add more tests for error handling, edge cases, etc. + }); + + describe('trainings.get', () => { + test('Calls the correct API route with the correct payload', async () => { + client[ 'instance' ].mockResolvedValueOnce({ + data: { + "id": "zz4ibbonubfz7carwiefibzgga", + "version": "{version}", + "status": "succeeded", + "input": { + "data": "...", + "param1": "..." + }, + "output": { + "version": "..." + }, + "error": null, + "logs": null, + "webhook_completed": null, + "started_at": null, + "created_at": "2023-03-28T21:47:58.566434Z", + "completed_at": null + } + }); + + const training = await client.trainings.get('zz4ibbonubfz7carwiefibzgga'); + expect(training.status).toBe('succeeded'); + + expect(client[ 'instance' ]).toHaveBeenCalledWith('/trainings/zz4ibbonubfz7carwiefibzgga', { + method: 'GET', + }); + }); + + // Add more tests for error handling, edge cases, etc. + }); + + describe('trainings.cancel', () => { + test('Calls the correct API route with the correct payload', async () => { + client[ 'instance' ].mockResolvedValueOnce({ + data: { + "id": "zz4ibbonubfz7carwiefibzgga", + "version": "{version}", + "status": "canceled", + "input": { + "data": "...", + "param1": "..." + }, + "output": { + "version": "..." + }, + "error": null, + "logs": null, + "webhook_completed": null, + "started_at": null, + "created_at": "2023-03-28T21:47:58.566434Z", + "completed_at": null + } + }); + + const training = await client.trainings.cancel("zz4ibbonubfz7carwiefibzgga"); + expect(training.status).toBe('canceled'); + + expect(client[ 'instance' ]).toHaveBeenCalledWith('/trainings/zz4ibbonubfz7carwiefibzgga/cancel', { + method: 'POST', + }); + }); + + // Add more tests for error handling, edge cases, etc. + }); + describe('run', () => { test('Calls the correct API routes', async () => { - client['instance'].mockResolvedValueOnce({ + client[ 'instance' ].mockResolvedValueOnce({ data: { id: 'ufawqhfynnddngldkgtslldrkq', status: 'processing', }, }); - client['instance'].mockResolvedValueOnce({ + client[ 'instance' ].mockResolvedValueOnce({ data: { id: 'ufawqhfynnddngldkgtslldrkq', status: 'succeeded', @@ -283,7 +397,7 @@ describe('Replicate client', () => { input: { text: 'Hello, world!' }, } ); - expect(client['instance']).toHaveBeenCalledWith('/predictions', { + expect(client[ 'instance' ]).toHaveBeenCalledWith('/predictions', { method: 'POST', data: { version: @@ -293,7 +407,7 @@ describe('Replicate client', () => { }, }, }); - expect(client['instance']).toHaveBeenCalledWith( + expect(client[ 'instance' ]).toHaveBeenCalledWith( '/predictions/ufawqhfynnddngldkgtslldrkq', { method: 'GET', diff --git a/lib/trainings.js b/lib/trainings.js new file mode 100644 index 00000000..c2512947 --- /dev/null +++ b/lib/trainings.js @@ -0,0 +1,53 @@ +/** + * Create a new training + * + * @param {string} model_owner - Required. The username of the user or organization who owns the model + * @param {string} model_name - Required. The name of the model + * @param {string} version_id - Required. The version ID + * @param {object} options + * @param {string} options.destination - Required. The destination for the trained version in the form "{username}/{model_name}" + * @param {object} options.input - Required. An object with the model inputs + * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates + * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @returns {Promise} Resolves with the data for the created training + */ +async function createTraining(model_owner, model_name, version_id, options) { + const { ...data } = options; + + const training = this.request(`/models/${model_owner}/${model_name}/versions/${version_id}/trainings`, { + method: 'POST', + data, + }); + + return training; +} + +/** + * Fetch a training by ID + * + * @param {string} training_id - Required. The training ID + * @returns {Promise} Resolves with the data for the training + */ +async function getTraining(training_id) { + return this.request(`/trainings/${training_id}`, { + method: 'GET', + }); +} + +/** + * Cancel a training by ID + * + * @param {string} training_id - Required. The training ID + * @returns {Promise} Resolves with the data for the training + */ +async function cancelTraining(training_id) { + return this.request(`/trainings/${training_id}/cancel`, { + method: 'POST', + }); +} + +module.exports = { + create: createTraining, + get: getTraining, + cancel: cancelTraining, +};