From e7358c41fa9d81cbc7c3593b757ce194816fced9 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 17 Sep 2023 08:33:43 -0700 Subject: [PATCH 1/3] Add progress property to Prediction Signed-off-by: Mattt Zmuda --- replicate/prediction.py | 33 +++++++++++++++++++++ tests/test_prediction.py | 62 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/replicate/prediction.py b/replicate/prediction.py index 298a4655..8dba7fda 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,4 +1,6 @@ +import re import time +from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional from replicate.base_model import BaseModel @@ -56,6 +58,37 @@ class Prediction(BaseModel): - `cancel`: A URL to cancel the prediction. """ + @dataclass + class Progress: + percentage: float + """The percentage of the prediction that has completed.""" + + current: int + """The number of items that have been processed.""" + + total: int + """The total number of items to process.""" + + @property + def progress(self) -> Optional[Progress]: + if self.logs is None or self.logs == "": + return None + + pattern = ( + r"^\s*(?P\d+)%\s*\|.+?\|\s*(?P\d+)\/(?P\d+)" + ) + re_compiled = re.compile(pattern) + + lines = self.logs.split("\n") + for i in reversed(range(len(lines))): + line = lines[i].strip() + if re_compiled.match(line): + matches = re_compiled.findall(line) + if len(matches) == 1: + percentage, current, total = map(int, matches[0]) + return Prediction.Progress(percentage / 100.0, current, total) + return None + def wait(self) -> None: """ Wait for prediction to finish. diff --git a/tests/test_prediction.py b/tests/test_prediction.py index ad6ccba6..4b330015 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,6 +1,8 @@ import responses from responses import matchers +from replicate.prediction import Prediction + from .factories import create_client, create_version @@ -214,3 +216,63 @@ def test_async_timings(): assert prediction.completed_at == "2022-04-26T20:02:27.648305Z" assert prediction.output == "hello world" assert prediction.metrics["predict_time"] == 1.2345 + + +def test_prediction_progress(): + client = create_client() + version = create_version(client) + prediction = Prediction( + id="ufawqhfynnddngldkgtslldrkq", version=version, status="starting" + ) + + lines = [ + "Using seed: 12345", + "0%| | 0/5 [00:00 Date: Sun, 17 Sep 2023 09:05:20 -0700 Subject: [PATCH 2/3] Refactor progress log parsing into Progress dataclass Signed-off-by: Mattt Zmuda --- replicate/prediction.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index 8dba7fda..e868de7e 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -69,25 +69,32 @@ class Progress: total: int """The total number of items to process.""" + _pattern = re.compile( + r"^\s*(?P\d+)%\s*\|.+?\|\s*(?P\d+)\/(?P\d+)" + ) + + @classmethod + def parse(cls, logs: str) -> Optional["Prediction.Progress"]: + lines = logs.split("\n") + for i in reversed(range(len(lines))): + line = lines[i].strip() + if cls._pattern.match(line): + matches = cls._pattern.findall(line) + if len(matches) == 1: + percentage, current, total = map(int, matches[0]) + return cls(percentage / 100.0, current, total) + + return None + @property def progress(self) -> Optional[Progress]: + """ + The progress of the prediction, if available. + """ if self.logs is None or self.logs == "": return None - pattern = ( - r"^\s*(?P\d+)%\s*\|.+?\|\s*(?P\d+)\/(?P\d+)" - ) - re_compiled = re.compile(pattern) - - lines = self.logs.split("\n") - for i in reversed(range(len(lines))): - line = lines[i].strip() - if re_compiled.match(line): - matches = re_compiled.findall(line) - if len(matches) == 1: - percentage, current, total = map(int, matches[0]) - return Prediction.Progress(percentage / 100.0, current, total) - return None + return Prediction.Progress.parse(self.logs) def wait(self) -> None: """ From 06cd1edcbb4aad1ee453ad6069bd4765ff76666f Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 17 Sep 2023 09:07:13 -0700 Subject: [PATCH 3/3] Add doc string Signed-off-by: Mattt Zmuda --- replicate/prediction.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/replicate/prediction.py b/replicate/prediction.py index e868de7e..f40a587a 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -75,6 +75,8 @@ class Progress: @classmethod def parse(cls, logs: str) -> Optional["Prediction.Progress"]: + """Parse the progress from the logs of a prediction.""" + lines = logs.split("\n") for i in reversed(range(len(lines))): line = lines[i].strip()