From 4ecaa027e8eaf94f064b2299b9a7f3c5c6e82b7c Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 10 Jul 2024 10:05:16 -0400 Subject: [PATCH 1/4] Add prediction ID to ModelError exception Signed-off-by: Rohan Mehta --- replicate/exceptions.py | 6 ++++++ replicate/prediction.py | 4 ++-- replicate/run.py | 4 ++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/replicate/exceptions.py b/replicate/exceptions.py index 6302d10f..497f522d 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -10,6 +10,12 @@ class ReplicateException(Exception): class ModelError(ReplicateException): """An error from user's code in a model.""" + prediction_id: str + + def __init__(self, error: Optional[str], prediction_id: str) -> None: + self.prediction_id = prediction_id + super().__init__(error) + class ReplicateError(ReplicateException): """ diff --git a/replicate/prediction.py b/replicate/prediction.py index 871566d7..b5590682 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -249,7 +249,7 @@ def output_iterator(self) -> Iterator[Any]: self.reload() if self.status == "failed": - raise ModelError(self.error) + raise ModelError(self.error, self.id) output = self.output or [] new_output = output[len(previous_output) :] @@ -272,7 +272,7 @@ async def async_output_iterator(self) -> AsyncIterator[Any]: await self.async_reload() if self.status == "failed": - raise ModelError(self.error) + raise ModelError(self.error, self.id) output = self.output or [] new_output = output[len(previous_output) :] diff --git a/replicate/run.py b/replicate/run.py index 975cc4dc..7fcbbf06 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -58,7 +58,7 @@ def run( prediction.wait() if prediction.status == "failed": - raise ModelError(prediction.error) + raise ModelError(prediction.error, prediction.id) return prediction.output @@ -97,7 +97,7 @@ async def async_run( await prediction.async_wait() if prediction.status == "failed": - raise 
ModelError(prediction.error) + raise ModelError(prediction.error, prediction.id) return prediction.output From 544b756b86b2df0e78c218774de09e7f05d8feca Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 18 Jul 2024 03:56:12 -0700 Subject: [PATCH 2/4] Pass prediction object to ModelError initializer Signed-off-by: Mattt Zmuda --- replicate/exceptions.py | 13 ++++++++----- replicate/prediction.py | 4 ++-- replicate/run.py | 4 ++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/replicate/exceptions.py b/replicate/exceptions.py index 497f522d..f52f9fb4 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -1,7 +1,10 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional import httpx +if TYPE_CHECKING: + from replicate.prediction import Prediction + class ReplicateException(Exception): """A base class for all Replicate exceptions.""" @@ -10,11 +13,11 @@ class ReplicateException(Exception): class ModelError(ReplicateException): """An error from user's code in a model.""" - prediction_id: str + prediction: "Prediction" - def __init__(self, error: Optional[str], prediction_id: str) -> None: - self.prediction_id = prediction_id - super().__init__(error) + def __init__(self, prediction: "Prediction") -> None: + self.prediction = prediction + super().__init__(prediction.error) class ReplicateError(ReplicateException): diff --git a/replicate/prediction.py b/replicate/prediction.py index b5590682..74c1946e 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -249,7 +249,7 @@ def output_iterator(self) -> Iterator[Any]: self.reload() if self.status == "failed": - raise ModelError(self.error, self.id) + raise ModelError(self) output = self.output or [] new_output = output[len(previous_output) :] @@ -272,7 +272,7 @@ async def async_output_iterator(self) -> AsyncIterator[Any]: await self.async_reload() if self.status == "failed": - raise ModelError(self.error, self.id) + raise ModelError(self) output = 
self.output or [] new_output = output[len(previous_output) :] diff --git a/replicate/run.py b/replicate/run.py index 7fcbbf06..ae1ca7e5 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -58,7 +58,7 @@ def run( prediction.wait() if prediction.status == "failed": - raise ModelError(prediction.error, prediction.id) + raise ModelError(prediction) return prediction.output @@ -97,7 +97,7 @@ async def async_run( await prediction.async_wait() if prediction.status == "failed": - raise ModelError(prediction.error, prediction.id) + raise ModelError(prediction) return prediction.output From 1a5f65f9df2b4e10f80c3976fe9a28d1b6dcf204 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 18 Jul 2024 03:56:34 -0700 Subject: [PATCH 3/4] Add test coverage for ModelError Signed-off-by: Mattt Zmuda --- tests/test_run.py | 71 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/tests/test_run.py b/tests/test_run.py index 84c8f3ab..d117eb32 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -7,7 +7,7 @@ import replicate from replicate.client import Client -from replicate.exceptions import ReplicateError +from replicate.exceptions import ModelError, ReplicateError @pytest.mark.vcr("run.yaml") @@ -184,3 +184,72 @@ def prediction_with_status(status: str) -> dict: ) assert output == "Hello, world!" 
+ + +@pytest.mark.asyncio +async def test_run_with_model_error(mock_replicate_api_token): + def prediction_with_status(status: str) -> dict: + return { + "id": "p1", + "model": "test/example", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2023-10-05T12:00:00.000000Z", + "source": "api", + "status": status, + "input": {"text": "world"}, + "output": None, + "error": "OOM" if status == "failed" else None, + "logs": "", + } + + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=prediction_with_status("processing"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + return_value=httpx.Response( + 200, + json=prediction_with_status("failed"), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json={ + "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", + "created_at": "2024-07-18T00:35:56.210272Z", + "cog_version": "0.9.10", + "openapi_schema": { + "openapi": "3.0.2", + }, + }, + ) + ) + router.route(host="api.replicate.com").pass_through() + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + with pytest.raises(ModelError) as excinfo: + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ) + + assert str(excinfo.value) == "OOM" + assert excinfo.value.prediction.error == "OOM" + assert excinfo.value.prediction.status == "failed" From b2779fb5bd301a3db603b62283ddb6d423b8c226 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 18 Jul 2024 04:21:17 -0700 Subject: [PATCH 4/4] Document ModelError in README Signed-off-by: Mattt Zmuda --- README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git 
a/README.md b/README.md index a4e80c54..eb411cf7 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,24 @@ or a handle to a file on your local device. "an astronaut riding a horse" ``` +`replicate.run` raises `ModelError` if the prediction fails. +You can access the exception's `prediction` property +to get more information about the failure. + +```python +import replicate +from replicate.exceptions import ModelError + +try: + output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "An astronaut riding a rainbow unicorn" }) +except ModelError as e: + if "(some known issue)" in e.prediction.logs: + pass + + print("Failed prediction: " + e.prediction.id) +``` + + ## Run a model and stream its output Replicate’s API supports server-sent event streams (SSEs) for language models.