From 10146e2b2635295b7661308fc6f0cdfe2bed48b2 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 29 May 2024 10:37:30 -0700 Subject: [PATCH 1/2] Fix regression in how array input values are transformed --- index.test.ts | 44 ++++++++++++++++++++++++++++++++++++++++++++ lib/util.js | 5 +++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/index.test.ts b/index.test.ts index 7502969d..7e5eefa9 100644 --- a/index.test.ts +++ b/index.test.ts @@ -114,6 +114,50 @@ describe("Replicate client", () => { const collections = await client.collections.list(); expect(collections.results.length).toBe(2); }); + + describe("predictions.create", () => { + test("Handles array input correctly", async () => { + const inputArray = ["Alice", "Bob", "Charlie"]; + + nock(BASE_URL) + .post("/predictions", { + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: { + text: inputArray, + }, + }) + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + model: "replicate/hello-world", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + urls: { + get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", + cancel: + "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + }, + created_at: "2022-04-26T22:13:06.224088Z", + started_at: null, + completed_at: null, + status: "starting", + input: { + text: inputArray, + }, + }); + + const response = await client.predictions.create({ + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: { + text: inputArray, + }, + }); + + expect(response.input).toEqual({ text: inputArray }); + expect(response.status).toBe("starting"); + }); + }); // Add more tests for error handling, edge cases, etc. }); diff --git a/lib/util.js b/lib/util.js index e164899e..3745d9f0 100644 --- a/lib/util.js +++ b/lib/util.js @@ -310,9 +310,10 @@ async function transformFileInputsToBase64EncodedDataURIs(inputs) { // Walk a JavaScript object and transform the leaf values. async function transform(value, mapper) { if (Array.isArray(value)) { - let copy = []; + const copy = []; for (const val of value) { - copy = await transform(val, mapper); + const transformed = await transform(val, mapper); + copy.push(transformed); } return copy; } From ab69c56ae0d1d9233cce9d0bf36edb0050876632 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 29 May 2024 10:53:27 -0700 Subject: [PATCH 2/2] Refactor tests --- index.test.ts | 150 +++++++++++++++++++++++++------------------------- 1 file changed, 74 insertions(+), 76 deletions(-) diff --git a/index.test.ts b/index.test.ts index 7e5eefa9..834b786f 100644 --- a/index.test.ts +++ b/index.test.ts @@ -114,50 +114,6 @@ describe("Replicate client", () => { const collections = await client.collections.list(); expect(collections.results.length).toBe(2); }); - - describe("predictions.create", () => { - test("Handles array input correctly", async () => { - const inputArray = ["Alice", "Bob", "Charlie"]; - - nock(BASE_URL) - .post("/predictions", { - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: { - text: inputArray, - }, - }) - .reply(200, { - id: "ufawqhfynnddngldkgtslldrkq", - model: "replicate/hello-world", - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - urls: { - get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: - "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", - }, - created_at: "2022-04-26T22:13:06.224088Z", - started_at: null, - completed_at: null, - status: "starting", - input: { - text: inputArray, - }, - }); - - const response = await client.predictions.create({ - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: { - text: inputArray, - }, - }); - - expect(response.input).toEqual({ text: inputArray }); - expect(response.status).toBe("starting"); - }); - }); // Add more tests for error handling, edge cases, etc. }); @@ -229,42 +185,84 @@ describe("Replicate client", () => { }); describe("predictions.create", () => { - test("Calls the correct API route with the correct payload", async () => { - nock(BASE_URL) - .post("/predictions") - .reply(200, { - id: "ufawqhfynnddngldkgtslldrkq", - model: "replicate/hello-world", - version: - "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - urls: { - get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: - "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", - }, - created_at: "2022-04-26T22:13:06.224088Z", - started_at: null, - completed_at: null, - status: "starting", - input: { - text: "Alice", + const predictionTestCases = [ + { + description: "String input", + input: { + text: "Alice", + }, + }, + { + description: "Number input", + input: { + text: 123, + }, + }, + { + description: "Boolean input", + input: { + text: true, + }, + }, + { + description: "Array input", + input: { + text: ["Alice", "Bob", "Charlie"], + }, + }, + { + description: "Object input", + input: { + text: { + name: "Alice", }, - output: null, - error: null, - logs: null, - metrics: {}, - }); - const prediction = await client.predictions.create({ + }, + }, + ].map((testCase) => ({ + ...testCase, + expectedResponse: { + id: "ufawqhfynnddngldkgtslldrkq", + model: "replicate/hello-world", version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: { - text: "Alice", + urls: { + get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", + cancel: + "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, - webhook: "http://test.host/webhook", - webhook_events_filter: ["output", "completed"], - }); - expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq"); - }); + input: testCase.input, + created_at: "2022-04-26T22:13:06.224088Z", + started_at: null, + completed_at: null, + status: "starting", + }, + })); + + test.each(predictionTestCases)( + "$description", + async ({ input, expectedResponse }) => { + nock(BASE_URL) + .post("/predictions", { + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: input as Record, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }) + .reply(200, expectedResponse); + + const response = await client.predictions.create({ + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: input as Record, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }); + + expect(response.input).toEqual(input); + expect(response.status).toBe(expectedResponse.status); + } + ); const fileTestCases = [ // Skip test case if File type is not available