From 8469af776539ee7b3d35da2c251e45abe8217b1c Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 5 Jul 2024 16:40:53 +0100 Subject: [PATCH] Re-raise 4xx responses from server when uploading files Fixes #270 We were falling back to base64 encoding the file data when requests to upload the file failed. However if the client makes an invalid request such as failing to include Authorization headers by forgetting to auth then we should surface these errors. This commit now re-raises any 4xx errors returned while attempting to upload a file and adds tests to verify the behavior. --- index.test.ts | 427 +++++++++++++++++++++----------------------------- lib/util.js | 35 ++--- 2 files changed, 189 insertions(+), 273 deletions(-) diff --git a/index.test.ts b/index.test.ts index 2645ca4d..c4d7e067 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,11 +1,5 @@ import { expect, jest, test } from "@jest/globals"; -import Replicate, { - ApiError, - Model, - Prediction, - validateWebhook, - parseProgressFromLogs, -} from "replicate"; +import Replicate, { ApiError, Model, Prediction, validateWebhook, parseProgressFromLogs } from "replicate"; import nock from "nock"; import { Readable } from "node:stream"; import { createReadableStream } from "./lib/stream"; @@ -42,8 +36,7 @@ const fileTestCases = [ describe("Replicate client", () => { let unmatched: any[] = []; - const handleNoMatch = (req: unknown, options: any, body: string) => - unmatched.push({ req, options, body }); + const handleNoMatch = (req: unknown, options: any, body: string) => unmatched.push({ req, options, body }); beforeEach(() => { client = new Replicate({ auth: "test-token" }); @@ -123,8 +116,7 @@ describe("Replicate client", () => { { name: "Super resolution", slug: "super-resolution", - description: - "Upscaling models that create high-quality images from low-quality images.", + description: "Upscaling models that create high-quality images from low-quality images.", }, { name: "Image classification", @@ -147,8 +139,7 @@ describe("Replicate client", () => { nock(BASE_URL).get("/collections/super-resolution").reply(200, { name: "Super resolution", slug: "super-resolution", - description: - "Upscaling models that create high-quality images from low-quality images.", + description: "Upscaling models that create high-quality images from low-quality images.", models: [], }); @@ -188,9 +179,7 @@ describe("Replicate client", () => { results: [{ url: "https://replicate.com/some-user/model-1" }], next: "https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw", }) - .get( - "/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw" - ) + .get("/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw") .reply(200, { results: [{ url: "https://replicate.com/some-user/model-2" }], next: null, @@ -248,12 +237,10 @@ describe("Replicate client", () => { expectedResponse: { id: "ufawqhfynnddngldkgtslldrkq", model: "replicate/hello-world", - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", urls: { get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: - "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + cancel: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, input: testCase.input, created_at: "2022-04-26T22:13:06.224088Z", @@ -263,79 +250,64 @@ describe("Replicate client", () => { }, })); - test.each(predictionTestCases)( - "$description", - async ({ input, expectedResponse }) => { - nock(BASE_URL) - .post("/predictions", { - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: input as Record, - webhook: "http://test.host/webhook", - webhook_events_filter: ["output", "completed"], - }) - .reply(200, expectedResponse); - - const response = await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + test.each(predictionTestCases)("$description", async ({ input, expectedResponse }) => { + nock(BASE_URL) + .post("/predictions", { + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: input as Record, webhook: "http://test.host/webhook", webhook_events_filter: ["output", "completed"], - }); + }) + .reply(200, expectedResponse); - expect(response.input).toEqual(input); - expect(response.status).toBe(expectedResponse.status); - } - ); + const response = await client.predictions.create({ + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: input as Record, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }); - test.each(fileTestCases)( - "converts a $type input into a Replicate file URL", - async ({ value: data, type }) => { - const mockedFetch = jest.spyOn(client, "fetch"); + expect(response.input).toEqual(input); + expect(response.status).toBe(expectedResponse.status); + }); - nock(BASE_URL) - .post("/files") - .reply(201, { - urls: { - get: "https://replicate.com/api/files/123", - }, - }) - .post( - "/predictions", - (body) => body.input.data === "https://replicate.com/api/files/123" - ) - .reply(201, (_uri: string, body: Record) => { - return body; - }); + test.each(fileTestCases)("converts a $type input into a Replicate file URL", async ({ value: data, type }) => { + const mockedFetch = jest.spyOn(client, "fetch"); - const prediction = await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: { - prompt: "Tell me a story", - data, + nock(BASE_URL) + .post("/files") + .reply(201, { + urls: { + get: "https://replicate.com/api/files/123", }, + }) + .post("/predictions", (body) => body.input.data === "https://replicate.com/api/files/123") + .reply(201, (_uri: string, body: Record) => { + return body; }); - expect(client.fetch).toHaveBeenCalledWith( - new URL("https://api.replicate.com/v1/files"), - { - method: "POST", - body: expect.any(FormData), - headers: expect.any(Object), - } - ); - const form = mockedFetch.mock.calls[0][1]?.body as FormData; - // @ts-ignore - expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`)); - - expect(prediction.input).toEqual({ + const prediction = await client.predictions.create({ + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: { prompt: "Tell me a story", - data: "https://replicate.com/api/files/123", - }); - } - ); + data, + }, + }); + + expect(client.fetch).toHaveBeenCalledWith(new URL("https://api.replicate.com/v1/files"), { + method: "POST", + body: expect.any(FormData), + headers: expect.any(Object), + }); + const form = mockedFetch.mock.calls[0][1]?.body as FormData; + // @ts-ignore + expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`)); + + expect(prediction.input).toEqual({ + prompt: "Tell me a story", + data: "https://replicate.com/api/files/123", + }); + }); test.each(fileTestCases)( "converts a $type input into a base64 encoded string", @@ -351,8 +323,7 @@ describe("Replicate client", () => { }); await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { prompt: "Tell me a story", data, @@ -361,7 +332,38 @@ describe("Replicate client", () => { }); expect(actual?.input.data).toEqual(expected); - } + }, + ); + + test.each(fileTestCases)( + "raises an error when the file upload fails with 4xx error for a $type input", + async ({ value: data, expected }) => { + let actual: Record | undefined; + nock(BASE_URL) + .post("/files") + .reply(401, "Unauthorized") + .post("/predictions") + .reply(201, (_uri: string, body: Record) => { + actual = body; + return body; + }); + + await expect(async () => { + await client.predictions.create({ + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: { + prompt: "Tell me a story", + data, + }, + stream: true, + }); + }).rejects.toThrowError( + expect.objectContaining({ + name: "ApiError", + message: expect.stringContaining("401"), + }), + ); + }, ); test("Passes stream parameter to API endpoint", async () => { @@ -373,8 +375,7 @@ describe("Replicate client", () => { }); await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { prompt: "Tell me a story", }, @@ -385,8 +386,7 @@ describe("Replicate client", () => { test("Throws an error if webhook URL is invalid", async () => { await expect(async () => { await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: "Alice", }, @@ -402,15 +402,14 @@ describe("Replicate client", () => { status: 400, detail: "Invalid input", }, - { "Content-Type": "application/json" } + { "Content-Type": "application/json" }, ); try { expect.hasAssertions(); await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: null, }, @@ -429,15 +428,14 @@ describe("Replicate client", () => { { detail: "Too many requests", }, - { "Content-Type": "application/json", "Retry-After": "1" } + { "Content-Type": "application/json", "Retry-After": "1" }, ) .post("/predictions") .reply(201, { id: "ufawqhfynnddngldkgtslldrkq", }); const prediction = await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: "Alice", }, @@ -451,19 +449,18 @@ describe("Replicate client", () => { { detail: "Internal server error", }, - { "Content-Type": "application/json" } + { "Content-Type": "application/json" }, ); await expect( client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: "Alice", }, - }) + }), ).rejects.toThrow( - `Request to https://api.replicate.com/v1/predictions failed with status 500 Internal Server Error: {"detail":"Internal server error"}.` + `Request to https://api.replicate.com/v1/predictions failed with status 500 Internal Server Error: {"detail":"Internal server error"}.`, ); }); }); @@ -475,12 +472,10 @@ describe("Replicate client", () => { .reply(200, { id: "rrr4z55ocneqzikepnug6xezpe", model: "stability-ai/stable-diffusion", - version: - "be04660a5b93ef2aff61e3668dedb4cbeb14941e62a3fd5998364a32d613e35e", + version: "be04660a5b93ef2aff61e3668dedb4cbeb14941e62a3fd5998364a32d613e35e", urls: { get: "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe", - cancel: - "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe/cancel", + cancel: "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe/cancel", }, created_at: "2022-09-13T22:54:18.578761Z", started_at: "2022-09-13T22:54:19.438525Z", @@ -499,9 +494,7 @@ describe("Replicate client", () => { predict_time: 4.484541, }, }); - const prediction = await client.predictions.get( - "rrr4z55ocneqzikepnug6xezpe" - ); + const prediction = await client.predictions.get("rrr4z55ocneqzikepnug6xezpe"); expect(prediction.id).toBe("rrr4z55ocneqzikepnug6xezpe"); }); @@ -513,16 +506,14 @@ describe("Replicate client", () => { { detail: "Too many requests", }, - { "Content-Type": "application/json", "Retry-After": "1" } + { "Content-Type": "application/json", "Retry-After": "1" }, ) .get("/predictions/rrr4z55ocneqzikepnug6xezpe") .reply(200, { id: "rrr4z55ocneqzikepnug6xezpe", }); - const prediction = await client.predictions.get( - "rrr4z55ocneqzikepnug6xezpe" - ); + const prediction = await client.predictions.get("rrr4z55ocneqzikepnug6xezpe"); expect(prediction.id).toBe("rrr4z55ocneqzikepnug6xezpe"); }); @@ -534,16 +525,14 @@ describe("Replicate client", () => { { detail: "Internal server error", }, - { "Content-Type": "application/json" } + { "Content-Type": "application/json" }, ) .get("/predictions/rrr4z55ocneqzikepnug6xezpe") .reply(200, { id: "rrr4z55ocneqzikepnug6xezpe", }); - const prediction = await client.predictions.get( - "rrr4z55ocneqzikepnug6xezpe" - ); + const prediction = await client.predictions.get("rrr4z55ocneqzikepnug6xezpe"); expect(prediction.id).toBe("rrr4z55ocneqzikepnug6xezpe"); }); }); @@ -555,12 +544,10 @@ describe("Replicate client", () => { .reply(200, { id: "ufawqhfynnddngldkgtslldrkq", model: "replicate/hello-world", - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", urls: { get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: - "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + cancel: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, created_at: "2022-04-26T22:13:06.224088Z", started_at: "2022-04-26T22:13:06.224088Z", @@ -575,9 +562,7 @@ describe("Replicate client", () => { metrics: {}, }); - const prediction = await client.predictions.cancel( - "ufawqhfynnddngldkgtslldrkq" - ); + const prediction = await client.predictions.cancel("ufawqhfynnddngldkgtslldrkq"); expect(prediction.status).toBe("canceled"); }); @@ -595,12 +580,10 @@ describe("Replicate client", () => { { id: "jpzd7hm5gfcapbfyt4mqytarku", model: "stability-ai/stable-diffusion", - version: - "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", + version: "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", urls: { get: "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku", - cancel: - "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku/cancel", + cancel: "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku/cancel", }, created_at: "2022-04-26T20:00:40.658234Z", started_at: "2022-04-26T20:00:84.583803Z", @@ -623,9 +606,7 @@ describe("Replicate client", () => { results: [{ id: "ufawqhfynnddngldkgtslldrkq" }], next: "https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw", }) - .get( - "/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw" - ) + .get("/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw") .reply(200, { results: [{ id: "rrr4z55ocneqzikepnug6xezpe" }], next: null, @@ -635,10 +616,7 @@ describe("Replicate client", () => { for await (const batch of client.paginate(client.predictions.list)) { results.push(...batch); } - expect(results).toEqual([ - { id: "ufawqhfynnddngldkgtslldrkq" }, - { id: "rrr4z55ocneqzikepnug6xezpe" }, - ]); + expect(results).toEqual([{ id: "ufawqhfynnddngldkgtslldrkq" }, { id: "rrr4z55ocneqzikepnug6xezpe" }]); // Add more tests for error handling, edge cases, etc. }); @@ -647,13 +625,10 @@ describe("Replicate client", () => { describe("trainings.create", () => { test("Calls the correct API route with the correct payload", async () => { nock(BASE_URL) - .post( - "/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings" - ) + .post("/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings") .reply(200, { id: "zz4ibbonubfz7carwiefibzgga", - version: - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + version: "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", status: "starting", input: { text: "...", @@ -675,25 +650,20 @@ describe("Replicate client", () => { input: { text: "...", }, - } + }, ); expect(training.id).toBe("zz4ibbonubfz7carwiefibzgga"); }); test("Throws an error if webhook is not a valid URL", async () => { await expect( - client.trainings.create( - "owner", - "model", - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", - { - destination: "new_owner/new_model", - input: { - text: "...", - }, - webhook: "invalid-url", - } - ) + client.trainings.create("owner", "model", "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", { + destination: "new_owner/new_model", + input: { + text: "...", + }, + webhook: "invalid-url", + }), ).rejects.toThrow("Invalid webhook URL"); }); @@ -753,9 +723,7 @@ describe("Replicate client", () => { completed_at: null, }); - const training = await client.trainings.cancel( - "zz4ibbonubfz7carwiefibzgga" - ); + const training = await client.trainings.cancel("zz4ibbonubfz7carwiefibzgga"); expect(training.status).toBe("canceled"); }); @@ -773,12 +741,10 @@ describe("Replicate client", () => { { id: "jpzd7hm5gfcapbfyt4mqytarku", model: "stability-ai/sdxl", - version: - "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", + version: "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", urls: { get: "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku", - cancel: - "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku/cancel", + cancel: "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku/cancel", }, created_at: "2022-04-26T20:00:40.658234Z", started_at: "2022-04-26T20:00:84.583803Z", @@ -801,9 +767,7 @@ describe("Replicate client", () => { results: [{ id: "ufawqhfynnddngldkgtslldrkq" }], next: "https://api.replicate.com/v1/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw", }) - .get( - "/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw" - ) + .get("/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw") .reply(200, { results: [{ id: "rrr4z55ocneqzikepnug6xezpe" }], next: null, @@ -813,10 +777,7 @@ describe("Replicate client", () => { for await (const batch of client.paginate(client.trainings.list)) { results.push(...batch); } - expect(results).toEqual([ - { id: "ufawqhfynnddngldkgtslldrkq" }, - { id: "rrr4z55ocneqzikepnug6xezpe" }, - ]); + expect(results).toEqual([{ id: "ufawqhfynnddngldkgtslldrkq" }, { id: "rrr4z55ocneqzikepnug6xezpe" }]); // Add more tests for error handling, edge cases, etc. }); @@ -829,12 +790,10 @@ describe("Replicate client", () => { .reply(200, { id: "mfrgcyzzme2wkmbwgzrgmntcg", model: "replicate/hello-world", - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", urls: { get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: - "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + cancel: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, created_at: "2022-09-10T09:44:22.165836Z", started_at: null, @@ -848,17 +807,13 @@ describe("Replicate client", () => { logs: null, metrics: {}, }); - const prediction = await client.deployments.predictions.create( - "replicate", - "greeter", - { - input: { - text: "Alice", - }, - webhook: "http://test.host/webhook", - webhook_events_filter: ["output", "completed"], - } - ); + const prediction = await client.deployments.predictions.create("replicate", "greeter", { + input: { + text: "Alice", + }, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }); expect(prediction.id).toBe("mfrgcyzzme2wkmbwgzrgmntcg"); }); // Add more tests for error handling, edge cases, etc. @@ -874,8 +829,7 @@ describe("Replicate client", () => { current_release: { number: 1, model: "stability-ai/sdxl", - version: - "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + version: "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", created_at: "2024-02-15T16:32:57.018467Z", created_by: { type: "organization", @@ -891,10 +845,7 @@ describe("Replicate client", () => { }, }); - const deployment = await client.deployments.get( - "acme", - "my-app-image-generator" - ); + const deployment = await client.deployments.get("acme", "my-app-image-generator"); expect(deployment.owner).toBe("acme"); expect(deployment.name).toBe("my-app-image-generator"); @@ -913,8 +864,7 @@ describe("Replicate client", () => { current_release: { number: 1, model: "stability-ai/sdxl", - version: - "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + version: "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", created_at: "2024-02-15T16:32:57.018467Z", created_by: { type: "organization", @@ -933,8 +883,7 @@ describe("Replicate client", () => { const deployment = await client.deployments.create({ name: "my-app-image-generator", model: "stability-ai/sdxl", - version: - "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + version: "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", hardware: "gpu-t4", min_instances: 1, max_instances: 5, @@ -957,8 +906,7 @@ describe("Replicate client", () => { current_release: { number: 2, model: "stability-ai/sdxl", - version: - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + version: "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", created_at: "2024-02-16T08:14:22.345678Z", created_by: { type: "organization", @@ -974,25 +922,18 @@ describe("Replicate client", () => { }, }); - const deployment = await client.deployments.update( - "acme", - "my-app-image-generator", - { - version: - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", - hardware: "gpu-a40-large", - min_instances: 3, - max_instances: 10, - } - ); + const deployment = await client.deployments.update("acme", "my-app-image-generator", { + version: "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + hardware: "gpu-a40-large", + min_instances: 3, + max_instances: 10, + }); expect(deployment.current_release.number).toBe(2); expect(deployment.current_release.version).toBe( - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532" - ); - expect(deployment.current_release.configuration.hardware).toBe( - "gpu-a40-large" + "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", ); + expect(deployment.current_release.configuration.hardware).toBe("gpu-a40-large"); expect(deployment.current_release.configuration.min_instances).toBe(3); expect(deployment.current_release.configuration.max_instances).toBe(10); }); @@ -1001,14 +942,9 @@ describe("Replicate client", () => { describe("deployments.delete", () => { test("Calls the correct API route with the correct payload", async () => { - nock(BASE_URL) - .delete("/deployments/acme/my-app-image-generator") - .reply(204); + nock(BASE_URL).delete("/deployments/acme/my-app-image-generator").reply(204); - const success = await client.deployments.delete( - "acme", - "my-app-image-generator" - ); + const success = await client.deployments.delete("acme", "my-app-image-generator"); expect(success).toBe(true); }); }); @@ -1054,8 +990,7 @@ describe("Replicate client", () => { status: "starting", created_at: "2023-11-27T13:35:45.99397566Z", urls: { - cancel: - "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel", + cancel: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel", get: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci", }, }); @@ -1265,7 +1200,7 @@ describe("Replicate client", () => { (prediction) => { const progress = parseProgressFromLogs(prediction); callback(prediction, progress); - } + }, ); expect(output).toBe("Goodbye!"); @@ -1277,7 +1212,7 @@ describe("Replicate client", () => { status: "starting", logs: null, }, - null + null, ); expect(callback).toHaveBeenNthCalledWith( @@ -1291,7 +1226,7 @@ describe("Replicate client", () => { percentage: 0.4, current: 2, total: 5, - } + }, ); expect(callback).toHaveBeenNthCalledWith( @@ -1305,7 +1240,7 @@ describe("Replicate client", () => { percentage: 0.8, current: 4, total: 5, - } + }, ); expect(callback).toHaveBeenNthCalledWith( @@ -1320,7 +1255,7 @@ describe("Replicate client", () => { percentage: 1.0, current: 5, total: 5, - } + }, ); expect(callback).toHaveBeenCalledTimes(4); @@ -1354,7 +1289,7 @@ describe("Replicate client", () => { input: { text: "Hello, world!" }, wait: { interval: 1 }, }, - progress + progress, ); expect(output).toBe("Goodbye!"); @@ -1397,9 +1332,7 @@ describe("Replicate client", () => { output: "foobar", }); - await expect( - client.run("a/b-1.0:abc123", { input: { text: "Hello, world!" } }) - ).resolves.not.toThrow(); + await expect(client.run("a/b-1.0:abc123", { input: { text: "Hello, world!" } })).resolves.not.toThrow(); }); test("Throws an error for invalid identifiers", async () => { @@ -1416,15 +1349,12 @@ describe("Replicate client", () => { test("Throws an error if webhook URL is invalid", async () => { await expect(async () => { - await client.run( - "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - { - input: { - text: "Alice", - }, - webhook: "invalid-url", - } - ); + await client.run("owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { + input: { + text: "Alice", + }, + webhook: "invalid-url", + }); }).rejects.toThrow("Invalid webhook URL"); }); @@ -1463,7 +1393,7 @@ describe("Replicate client", () => { input: { text: "Hello, world!" }, signal, }, - onProgress + onProgress, ); expect(body).toBeDefined(); @@ -1475,19 +1405,19 @@ describe("Replicate client", () => { 1, expect.objectContaining({ status: "processing", - }) + }), ); expect(onProgress).toHaveBeenNthCalledWith( 2, expect.objectContaining({ status: "processing", - }) + }), ); expect(onProgress).toHaveBeenNthCalledWith( 3, expect.objectContaining({ status: "canceled", - }) + }), ); scope.done(); @@ -1512,8 +1442,7 @@ describe("Replicate client", () => { "Content-Type": "application/json", "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", "Webhook-Timestamp": "1614265330", - "Webhook-Signature": - "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", + "Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", }, body: `{"test": 2432232314}`, }); @@ -1556,7 +1485,7 @@ describe("Replicate client", () => { id: EVENT_2 data: {} - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1587,7 +1516,7 @@ describe("Replicate client", () => { id: EVENT_3 data: {} - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1621,7 +1550,7 @@ describe("Replicate client", () => { id: EVENT_2 data: {} - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1653,7 +1582,7 @@ describe("Replicate client", () => { id: EVENT_2 data: {} - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1774,7 +1703,7 @@ describe("Replicate client", () => { id: EVENT_1 data: hello world - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1796,7 +1725,7 @@ describe("Replicate client", () => { id: EVENT_2 data: An unexpected error occurred - `.replace(/^[ ]+/gm, "") + `.replace(/^[ ]+/gm, ""), ); const iterator = stream[Symbol.asyncIterator](); @@ -1804,9 +1733,7 @@ describe("Replicate client", () => { done: false, value: { event: "output", id: "EVENT_1", data: "hello world" }, }); - await expect(iterator.next()).rejects.toThrowError( - "An unexpected error occurred" - ); + await expect(iterator.next()).rejects.toThrowError("An unexpected error occurred"); expect(await iterator.next()).toEqual({ done: true }); }); @@ -1814,7 +1741,7 @@ describe("Replicate client", () => { const stream = createStream("{}", 500); const iterator = stream[Symbol.asyncIterator](); await expect(iterator.next()).rejects.toThrowError( - "Request to https://stream.replicate.com/fake_stream failed with status 500" + "Request to https://stream.replicate.com/fake_stream failed with status 500", ); expect(await iterator.next()).toEqual({ done: true }); }); diff --git a/lib/util.js b/lib/util.js index 3745d9f0..b4483aed 100644 --- a/lib/util.js +++ b/lib/util.js @@ -67,18 +67,11 @@ async function validateWebhook(requestData, secret) { const signedContent = `${id}.${timestamp}.${body}`; - const computedSignature = await createHMACSHA256( - signingSecret.split("_").pop(), - signedContent - ); + const computedSignature = await createHMACSHA256(signingSecret.split("_").pop(), signedContent); - const expectedSignatures = signature - .split(" ") - .map((sig) => sig.split(",")[1]); + const expectedSignatures = signature.split(" ").map((sig) => sig.split(",")[1]); - return expectedSignatures.some( - (expectedSignature) => expectedSignature === computedSignature - ); + return expectedSignatures.some((expectedSignature) => expectedSignature === computedSignature); } /** @@ -105,13 +98,9 @@ async function createHMACSHA256(secret, data) { crypto = require.call(null, "node:crypto").webcrypto; } - const key = await crypto.subtle.importKey( - "raw", - base64ToBytes(secret), - { name: "HMAC", hash: "SHA-256" }, - false, - ["sign"] - ); + const key = await crypto.subtle.importKey("raw", base64ToBytes(secret), { name: "HMAC", hash: "SHA-256" }, false, [ + "sign", + ]); const signature = await crypto.subtle.sign("HMAC", key, encoder.encode(data)); return bytesToBase64(signature); @@ -235,6 +224,9 @@ async function transformFileInputs(client, inputs, strategy) { try { return await transformFileInputsToReplicateFileURLs(client, inputs); } catch (error) { + if (error instanceof ApiError && error.response.status >= 400 && error.response.status < 500) { + throw error; + } return await transformFileInputsToBase64EncodedDataURIs(inputs); } default: @@ -296,7 +288,7 @@ async function transformFileInputsToBase64EncodedDataURIs(inputs) { totalBytes += buffer.byteLength; if (totalBytes > MAX_DATA_URI_SIZE) { throw new Error( - `Combined filesize of prediction ${totalBytes} bytes exceeds 10mb limit for inline encoding, please provide URLs instead` + `Combined filesize of prediction ${totalBytes} bytes exceeds 10mb limit for inline encoding, please provide URLs instead`, ); } @@ -354,14 +346,11 @@ function isPlainObject(value) { if (proto === null) { return true; } - const Ctor = - Object.prototype.hasOwnProperty.call(proto, "constructor") && - proto.constructor; + const Ctor = Object.prototype.hasOwnProperty.call(proto, "constructor") && proto.constructor; return ( typeof Ctor === "function" && Ctor instanceof Ctor && - Function.prototype.toString.call(Ctor) === - Function.prototype.toString.call(Object) + Function.prototype.toString.call(Ctor) === Function.prototype.toString.call(Object) ); }