diff --git a/README.md b/README.md index 0cb7c285..5fe325f9 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,8 @@ Install with `npm install replicate` Set your API token as an environment variable called `REPLICATE_API_TOKEN`. +### Making preedictions + To run a prediction and return its output: ```js @@ -24,11 +26,11 @@ const prediction = await replicate "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf" ) .predict({ - prompt: "painting of a cat by andy warhol", + prompt: "an astronaut riding on a horse", }); console.log(prediction.output); -// "https://replicate.delivery/pbxt/oeJLu7D1Y7UWESpzerfINqgwZgONSCubSjSw0msf8i4AP2BCB/out-0.png" +// "https://replicate.delivery/pbxt/nSREat5H54rxGJo1kk2xLLG2fpr0NBE0HBD5L0jszLoy8oSIA/out-0.png" ``` If you want to do something like updating progress while the prediction is @@ -43,7 +45,7 @@ await replicate ) .predict( { - prompt: "painting of a cat by andy warhol", + prompt: "an astronaut riding on a horse", }, { onUpdate: (prediction) => { @@ -64,7 +66,7 @@ const prediction = await replicate "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf" ) .createPrediction({ - prompt: "painting of a cat by andy warhol", + prompt: "an astronaut riding on a horse", }); console.log(prediction.status); // "starting" @@ -73,6 +75,32 @@ console.log(prediction.status); // "starting" From there, you can fetch the current status of the prediction using `await prediction.load()` or `await replicate.prediction(prediction.id).load()`. +#### Webhooks + +You can also provide webhook configuration to have Replicate send POST requests +to your service when certain events occur: + +```js +import replicate from "replicate"; + +await replicate + .model( + "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf" + ) + .createPrediction( + { + prompt: "an astronaut riding on a horse", + }, + { + // See https://replicate.com/docs/reference/http#create-prediction--webhook + webhook: "https://your.host/webhook", + + // See https://replicate.com/docs/reference/http#create-prediction--webhook_events_filter + webhookEventsFilter: ["output", "completed"], + } + ); +``` + ## Contributing While we'd love to accept contributions to this library, please open an issue diff --git a/lib/Model.js b/lib/Model.js index 3296dfe5..ed27db36 100644 --- a/lib/Model.js +++ b/lib/Model.js @@ -122,7 +122,7 @@ export default class Model extends ReplicateObject { return prediction; } - async createPrediction(input) { + async createPrediction(input, { webhook, webhookEventsFilter } = {}) { // This is here and not on `Prediction` because conceptually, a prediction // from a model "belongs" to the model. It's an odd feature of the API that // the prediction creation isn't an action on the model (or that it doesn't @@ -131,6 +131,8 @@ export default class Model extends ReplicateObject { const predictionData = await this.client.request("POST /v1/predictions", { version: this.version, input, + webhook, + webhook_events_filter: webhookEventsFilter, }); return new Prediction(predictionData, this); diff --git a/lib/Model.test.js b/lib/Model.test.js index aabbe47e..a4a7486e 100644 --- a/lib/Model.test.js +++ b/lib/Model.test.js @@ -276,4 +276,44 @@ describe("createPrediction()", () => { input: { text: "test text" }, }); }); + + it("supports webhook URL", async () => { + jest.spyOn(client, "request").mockResolvedValue({ + id: "testprediction", + status: PredictionStatus.SUCCEEDED, + }); + + await model.createPrediction( + { text: "test text" }, + { webhook: "http://test.host/webhook" } + ); + + expect(client.request).toHaveBeenCalledWith("POST /v1/predictions", { + version: "testversion", + input: { text: "test text" }, + webhook: "http://test.host/webhook", + }); + }); + + it("supports webhook events filter", async () => { + jest.spyOn(client, "request").mockResolvedValue({ + id: "testprediction", + status: PredictionStatus.SUCCEEDED, + }); + + await model.createPrediction( + { text: "test text" }, + { + webhook: "http://test.host/webhook", + webhookEventsFilter: ["output", "completed"], + } + ); + + expect(client.request).toHaveBeenCalledWith("POST /v1/predictions", { + version: "testversion", + input: { text: "test text" }, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }); + }); });