diff --git a/README.md b/README.md index 3becbddd..392ceeda 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,9 @@ const prediction = await replicate "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf" ) .predict({ - prompt: "an astronaut riding on a horse", + input: { + prompt: "an astronaut riding on a horse", + }, }); console.log(prediction.output); @@ -45,7 +47,9 @@ await replicate ) .predict( { - prompt: "an astronaut riding on a horse", + input: { + prompt: "an astronaut riding on a horse", + }, }, { onUpdate: (prediction) => { @@ -66,7 +70,9 @@ const prediction = await replicate "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf" ) .createPrediction({ - prompt: "an astronaut riding on a horse", + input: { + prompt: "an astronaut riding on a horse", + }, }); console.log(prediction.status); // "starting" @@ -89,7 +95,9 @@ await replicate ) .createPrediction( { - prompt: "an astronaut riding on a horse", + input: { + prompt: "an astronaut riding on a horse", + }, }, { // See https://replicate.com/docs/reference/http#create-prediction--webhook diff --git a/lib/Model.js b/lib/Model.js index ed27db36..1c0bcec5 100644 --- a/lib/Model.js +++ b/lib/Model.js @@ -47,7 +47,7 @@ export default class Model extends ReplicateObject { } async predict( - input, + { input }, { onUpdate = noop, onTemporaryError = noop, @@ -122,12 +122,16 @@ export default class Model extends ReplicateObject { return prediction; } - async createPrediction(input, { webhook, webhookEventsFilter } = {}) { + async createPrediction({ input }, { webhook, webhookEventsFilter } = {}) { // This is here and not on `Prediction` because conceptually, a prediction // from a model "belongs" to the model. It's an odd feature of the API that // the prediction creation isn't an action on the model (or that it doesn't // actually use the model information, only the version), but we don't need // to expose that to users of this library. + if (!input) { + throw new ReplicateError("input is required"); + } + const predictionData = await this.client.request("POST /v1/predictions", { version: this.version, input, diff --git a/lib/Model.test.js b/lib/Model.test.js index 057bd75e..a21c0060 100644 --- a/lib/Model.test.js +++ b/lib/Model.test.js @@ -86,7 +86,7 @@ describe("predict()", () => { ); await model.predict( - { text: "test text" }, + { input: { text: "test text" } }, {}, { defaultPollingInterval: 0 } ); @@ -128,7 +128,7 @@ describe("predict()", () => { .mockImplementation((action) => requestMockReturnValues[action]); await model.predict( - { text: "test text" }, + { input: { text: "test text" } }, {}, { defaultPollingInterval: 0 } ); @@ -182,7 +182,7 @@ describe("predict()", () => { }); const prediction = await model.predict( - { text: "test text" }, + { input: { text: "test text" } }, {}, { defaultPollingInterval: 0 } ); @@ -237,7 +237,7 @@ describe("predict()", () => { const backoffFn = jest.fn(() => 0); const prediction = await model.predict( - { text: "test text" }, + { input: { text: "test text" } }, {}, { defaultPollingInterval: 0, backoffFn } ); @@ -255,7 +255,7 @@ describe("createPrediction()", () => { status: PredictionStatus.SUCCEEDED, }); - await model.createPrediction({ text: "test text" }); + await model.createPrediction({ input: { text: "test text" } }); expect(client.request).toHaveBeenCalledWith("POST /v1/predictions", { version: "testversion", @@ -270,7 +270,7 @@ describe("createPrediction()", () => { }); await model.createPrediction( - { text: "test text" }, + { input: { text: "test text" } }, { webhook: "http://test.host/webhook" } ); @@ -288,7 +288,7 @@ describe("createPrediction()", () => { }); await model.createPrediction( - { text: "test text" }, + { input: { text: "test text" } }, { webhook: "http://test.host/webhook", webhookEventsFilter: ["output", "completed"],