From 7f72b2b6ff5958a4586d2939f642c7e619d94f65 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 19 Feb 2024 06:06:42 -0800 Subject: [PATCH 1/6] Add parsePredictionProgress helper function --- index.d.ts | 6 +++ index.js | 7 +++- index.test.ts | 110 +++++++++++++++++++++++++++++++++++++++----------- lib/util.js | 36 ++++++++++++++++- 4 files changed, 133 insertions(+), 26 deletions(-) diff --git a/index.d.ts b/index.d.ts index 69e651b1..0ce87f3f 100644 --- a/index.d.ts +++ b/index.d.ts @@ -280,4 +280,10 @@ declare module "replicate" { }, secret: string ): boolean; + + export function parsePredictionProgress(logs: Prediction | string): { + percentage: number; + current: number; + total: number; + }; } diff --git a/index.js b/index.js index 83b98887..b03af1af 100644 --- a/index.js +++ b/index.js @@ -1,7 +1,11 @@ const ApiError = require("./lib/error"); const ModelVersionIdentifier = require("./lib/identifier"); const { Stream } = require("./lib/stream"); -const { withAutomaticRetries, validateWebhook } = require("./lib/util"); +const { + withAutomaticRetries, + validateWebhook, + parsePredictionProgress, +} = require("./lib/util"); const accounts = require("./lib/accounts"); const collections = require("./lib/collections"); @@ -375,3 +379,4 @@ class Replicate { module.exports = Replicate; module.exports.validateWebhook = validateWebhook; +module.exports.parsePredictionProgress = parsePredictionProgress; diff --git a/index.test.ts b/index.test.ts index f00a7e68..9d366fac 100644 --- a/index.test.ts +++ b/index.test.ts @@ -4,6 +4,7 @@ import Replicate, { Model, Prediction, validateWebhook, + parsePredictionProgress, } from "replicate"; import nock from "nock"; import fetch from "cross-fetch"; @@ -888,29 +889,55 @@ describe("Replicate client", () => { }); describe("run", () => { - test("Calls the correct API routes for a version", async () => { - const firstPollingRequest = true; - + test("Calls the correct API routes", async () => { nock(BASE_URL) .post("/predictions") .reply(201, { id: "ufawqhfynnddngldkgtslldrkq", status: "starting", + logs: null, }) .get("/predictions/ufawqhfynnddngldkgtslldrkq") - .twice() .reply(200, { id: "ufawqhfynnddngldkgtslldrkq", status: "processing", + logs: [ + "Using seed: 12345", + "0%| | 0/5 [00:00 { input: { text: "Hello, world!" }, wait: { interval: 1 }, }, - progress + (prediction) => { + const progress = parsePredictionProgress(prediction); + callback(prediction, progress); + } ); expect(output).toBe("Goodbye!"); - expect(progress).toHaveBeenNthCalledWith(1, { - id: "ufawqhfynnddngldkgtslldrkq", - status: "starting", - }); + expect(callback).toHaveBeenNthCalledWith( + 1, + { + id: "ufawqhfynnddngldkgtslldrkq", + status: "starting", + logs: null, + }, + null + ); - expect(progress).toHaveBeenNthCalledWith(2, { - id: "ufawqhfynnddngldkgtslldrkq", - status: "processing", - }); + expect(callback).toHaveBeenNthCalledWith( + 2, + { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: expect.any(String), + }, + { + percentage: 0.4, + current: 2, + total: 5, + } + ); - expect(progress).toHaveBeenNthCalledWith(3, { - id: "ufawqhfynnddngldkgtslldrkq", - status: "processing", - }); + expect(callback).toHaveBeenNthCalledWith( + 3, + { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: expect.any(String), + }, + { + percentage: 0.8, + current: 4, + total: 5, + } + ); - expect(progress).toHaveBeenNthCalledWith(4, { - id: "ufawqhfynnddngldkgtslldrkq", - status: "succeeded", - output: "Goodbye!", - }); + expect(callback).toHaveBeenNthCalledWith( + 4, + { + id: "ufawqhfynnddngldkgtslldrkq", + status: "succeeded", + logs: expect.any(String), + output: "Goodbye!", + }, + { + percentage: 1.0, + current: 5, + total: 5, + } + ); - expect(progress).toHaveBeenCalledTimes(4); + expect(callback).toHaveBeenCalledTimes(4); }); test("Calls the correct API routes for a model", async () => { diff --git a/lib/util.js b/lib/util.js index 48d7563c..369b0bce 100644 --- a/lib/util.js +++ b/lib/util.js @@ -246,4 +246,38 @@ function isPlainObject(value) { ); } -module.exports = { transformFileInputs, validateWebhook, withAutomaticRetries }; +/** + * Parse prediction progress from logs. + * @param {object|string} input - A prediction object or string. + * @returns {object} - An object with the percentage, current, and total. + */ +function parsePredictionProgress(input) { + const logs = typeof input === "object" && input.logs ? input.logs : input; + if (!logs || typeof logs !== "string") { + return null; + } + + const pattern = /^\s*(\d+)%\s*\|.+?\|\s*(\d+)\/(\d+)/; + const lines = logs.split("\n").reverse(); + + for (const line of lines) { + const matches = line.match(pattern); + + if (matches && matches.length === 4) { + return { + percentage: parseInt(matches[1], 10) / 100, + current: parseInt(matches[2], 10), + total: parseInt(matches[3], 10), + }; + } + } + + return null; +} + +module.exports = { + transformFileInputs, + validateWebhook, + withAutomaticRetries, + parsePredictionProgress, +}; From 4e2802d271e09847e58f33658a0711bcc6af38de Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 21 Feb 2024 03:37:16 -0800 Subject: [PATCH 2/6] Rename parsePredictionProgress to parseProgress --- index.d.ts | 2 +- index.js | 4 ++-- index.test.ts | 4 ++-- lib/util.js | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/index.d.ts b/index.d.ts index 0ce87f3f..74286383 100644 --- a/index.d.ts +++ b/index.d.ts @@ -281,7 +281,7 @@ declare module "replicate" { secret: string ): boolean; - export function parsePredictionProgress(logs: Prediction | string): { + export function parseProgress(logs: Prediction | string): { percentage: number; current: number; total: number; diff --git a/index.js b/index.js index b03af1af..40ff6eb6 100644 --- a/index.js +++ b/index.js @@ -4,7 +4,7 @@ const { Stream } = require("./lib/stream"); const { withAutomaticRetries, validateWebhook, - parsePredictionProgress, + parseProgress, } = require("./lib/util"); const accounts = require("./lib/accounts"); @@ -379,4 +379,4 @@ class Replicate { module.exports = Replicate; module.exports.validateWebhook = validateWebhook; -module.exports.parsePredictionProgress = parsePredictionProgress; +module.exports.parseProgress = parseProgress; diff --git a/index.test.ts b/index.test.ts index 9d366fac..f0ed7f94 100644 --- a/index.test.ts +++ b/index.test.ts @@ -4,7 +4,7 @@ import Replicate, { Model, Prediction, validateWebhook, - parsePredictionProgress, + parseProgress, } from "replicate"; import nock from "nock"; import fetch from "cross-fetch"; @@ -946,7 +946,7 @@ describe("Replicate client", () => { wait: { interval: 1 }, }, (prediction) => { - const progress = parsePredictionProgress(prediction); + const progress = parseProgress(prediction); callback(prediction, progress); } ); diff --git a/lib/util.js b/lib/util.js index 369b0bce..84678837 100644 --- a/lib/util.js +++ b/lib/util.js @@ -247,11 +247,11 @@ function isPlainObject(value) { } /** - * Parse prediction progress from logs. + * Parse progress from prediction logs. * @param {object|string} input - A prediction object or string. * @returns {object} - An object with the percentage, current, and total. */ -function parsePredictionProgress(input) { +function parseProgress(input) { const logs = typeof input === "object" && input.logs ? input.logs : input; if (!logs || typeof logs !== "string") { return null; @@ -279,5 +279,5 @@ module.exports = { transformFileInputs, validateWebhook, withAutomaticRetries, - parsePredictionProgress, + parseProgress, }; From 051d00ad093d53d04f8a5d36b602e42221d8f12f Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Fri, 23 Feb 2024 06:03:48 -0800 Subject: [PATCH 3/6] Annotate possible null return value in jsdoc --- lib/util.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/util.js b/lib/util.js index 84678837..a65c5474 100644 --- a/lib/util.js +++ b/lib/util.js @@ -249,7 +249,7 @@ function isPlainObject(value) { /** * Parse progress from prediction logs. * @param {object|string} input - A prediction object or string. - * @returns {object} - An object with the percentage, current, and total. + * @returns {(object|null)} - An object with the percentage, current, and total, or null if no progress can be parsed. */ function parseProgress(input) { const logs = typeof input === "object" && input.logs ? input.logs : input; From ac0a8d84b43db8411d2426a2f4b04af26cbe7d20 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Fri, 23 Feb 2024 06:04:53 -0800 Subject: [PATCH 4/6] Annotate possible null return value in type definition --- index.d.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/index.d.ts b/index.d.ts index 74286383..3a5a3514 100644 --- a/index.d.ts +++ b/index.d.ts @@ -285,5 +285,5 @@ declare module "replicate" { percentage: number; current: number; total: number; - }; + } | null; } From 8e7935ca8ab567cc03e0a3a44cb5625bea2efbc1 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Fri, 23 Feb 2024 06:05:19 -0800 Subject: [PATCH 5/6] Rename parseProgress to parseProgressFromLogs --- index.d.ts | 2 +- index.js | 4 ++-- index.test.ts | 4 ++-- lib/util.js | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/index.d.ts b/index.d.ts index 3a5a3514..8dc998a8 100644 --- a/index.d.ts +++ b/index.d.ts @@ -281,7 +281,7 @@ declare module "replicate" { secret: string ): boolean; - export function parseProgress(logs: Prediction | string): { + export function parseProgressFromLogs(logs: Prediction | string): { percentage: number; current: number; total: number; diff --git a/index.js b/index.js index 40ff6eb6..24376fe1 100644 --- a/index.js +++ b/index.js @@ -4,7 +4,7 @@ const { Stream } = require("./lib/stream"); const { withAutomaticRetries, validateWebhook, - parseProgress, + parseProgressFromLogs, } = require("./lib/util"); const accounts = require("./lib/accounts"); @@ -379,4 +379,4 @@ class Replicate { module.exports = Replicate; module.exports.validateWebhook = validateWebhook; -module.exports.parseProgress = parseProgress; +module.exports.parseProgressFromLogs = parseProgressFromLogs; diff --git a/index.test.ts b/index.test.ts index f0ed7f94..97abc6fb 100644 --- a/index.test.ts +++ b/index.test.ts @@ -4,7 +4,7 @@ import Replicate, { Model, Prediction, validateWebhook, - parseProgress, + parseProgressFromLogs, } from "replicate"; import nock from "nock"; import fetch from "cross-fetch"; @@ -946,7 +946,7 @@ describe("Replicate client", () => { wait: { interval: 1 }, }, (prediction) => { - const progress = parseProgress(prediction); + const progress = parseProgressFromLogs(prediction); callback(prediction, progress); } ); diff --git a/lib/util.js b/lib/util.js index a65c5474..bb0c0920 100644 --- a/lib/util.js +++ b/lib/util.js @@ -251,7 +251,7 @@ function isPlainObject(value) { * @param {object|string} input - A prediction object or string. * @returns {(object|null)} - An object with the percentage, current, and total, or null if no progress can be parsed. */ -function parseProgress(input) { +function parseProgressFromLogs(input) { const logs = typeof input === "object" && input.logs ? input.logs : input; if (!logs || typeof logs !== "string") { return null; @@ -279,5 +279,5 @@ module.exports = { transformFileInputs, validateWebhook, withAutomaticRetries, - parseProgress, + parseProgressFromLogs, }; From 51ce4a7b7dcc63307b056e77565ffebc41cdd639 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Fri, 23 Feb 2024 06:13:48 -0800 Subject: [PATCH 6/6] Expand documentation of parseProgressFromLogs --- lib/util.js | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/lib/util.js b/lib/util.js index bb0c0920..949bafa4 100644 --- a/lib/util.js +++ b/lib/util.js @@ -248,6 +248,23 @@ function isPlainObject(value) { /** * Parse progress from prediction logs. + * + * This function supports log statements in the following format, + * which are generated by https://github.com/tqdm/tqdm and similar libraries: + * + * ``` + * 76%|████████████████████████████ | 7568/10000 [00:33<00:10, 229.00it/s] + * ``` + * + * @example + * const progress = parseProgressFromLogs("76%|████████████████████████████ | 7568/10000 [00:33<00:10, 229.00it/s]"); + * console.log(progress); + * // { + * // percentage: 0.76, + * // current: 7568, + * // total: 10000, + * // } + * * @param {object|string} input - A prediction object or string. * @returns {(object|null)} - An object with the percentage, current, and total, or null if no progress can be parsed. */