From bad98a26c21d2eec17b3626681cf7e5e5fb1cb2f Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 18 Mar 2024 06:02:45 -0700 Subject: [PATCH 1/5] Improve client HTTP errors Signed-off-by: Mattt Zmuda --- replicate/client.py | 2 +- replicate/exceptions.py | 72 ++++++++++++++++++++++++++++++++++++++++- tests/test_client.py | 54 +++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 2 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 3656d826..b267e1c1 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -365,4 +365,4 @@ def _build_httpx_client( def _raise_for_status(resp: httpx.Response) -> None: if 400 <= resp.status_code < 600: - raise ReplicateError(resp.json()["detail"]) + raise ReplicateError.from_response(resp) diff --git a/replicate/exceptions.py b/replicate/exceptions.py index e1aa51c4..29ec776e 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -1,3 +1,8 @@ +from typing import Optional + +import httpx + + class ReplicateException(Exception): """A base class for all Replicate exceptions.""" @@ -7,4 +12,69 @@ class ModelError(ReplicateException): class ReplicateError(ReplicateException): - """An error from Replicate.""" + """ + An error from Replicate's API. + + This class represents a problem details response as defined in RFC 7807. + """ + + type: Optional[str] + """A URI that identifies the error type.""" + + title: Optional[str] + """A short, human-readable summary of the error.""" + + status: Optional[int] + """The HTTP status code.""" + + detail: Optional[str] + """A human-readable explanation specific to this occurrence of the error.""" + + instance: Optional[str] + """A URI that identifies the specific occurrence of the error.""" + + def __init__( + self, + type: Optional[str] = None, + title: Optional[str] = None, + status: Optional[int] = None, + detail: Optional[str] = None, + instance: Optional[str] = None, + ) -> None: + self.type = type + self.title = title + self.status = status + self.detail = detail + self.instance = instance + + @classmethod + def from_response(cls, response: httpx.Response) -> "ReplicateError": + """Create a ReplicateError from a requests.Response.""" + try: + data = response.json() + except ValueError: + data = {} + + return cls( + type=data.get("type"), + title=data.get("title"), + detail=data.get("detail"), + status=response.status_code, + instance=data.get("instance"), + ) + + def to_dict(self) -> dict: + return { + key: value + for key, value in { + "type": self.type, + "title": self.title, + "status": self.status, + "detail": self.detail, + "instance": self.instance, + }.items() + if value is not None + } + + def __str__(self) -> str: + return f"ReplicateError: {self.to_dict()}" diff --git a/tests/test_client.py b/tests/test_client.py index 95636771..a95711a9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -31,3 +31,57 @@ async def test_authorization_when_setting_environ_after_import(): client = replicate.Client(transport=httpx.MockTransport(router.handler)) resp = client._request("GET", "/") assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_client_error_handling(): + import replicate + from replicate.exceptions import ReplicateError + + router = respx.Router() + router.route( + method="GET", + url="https://api.replicate.com/", + headers={"Authorization": "Token test-client-error"}, + ).mock( + return_value=httpx.Response( + 400, + json={"detail": "Client error occurred"}, + ) + ) + + token = "test-client-error" # noqa: S105 + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": token}): + client = replicate.Client(transport=httpx.MockTransport(router.handler)) + with pytest.raises(ReplicateError) as exc_info: + client._request("GET", "/") + assert "'status': 400" in str(exc_info.value) + assert "'detail': 'Client error occurred'" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_server_error_handling(): + import replicate + from replicate.exceptions import ReplicateError + + router = respx.Router() + router.route( + method="GET", + url="https://api.replicate.com/", + headers={"Authorization": "Token test-server-error"}, + ).mock( + return_value=httpx.Response( + 500, + json={"detail": "Server error occurred"}, + ) + ) + + token = "test-server-error" # noqa: S105 + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": token}): + client = replicate.Client(transport=httpx.MockTransport(router.handler)) + with pytest.raises(ReplicateError) as exc_info: + client._request("GET", "/") + assert "'status': 500" in str(exc_info.value) + assert "'detail': 'Server error occurred'" in str(exc_info.value) From 487a9a7c72f07cd71bba45ca2b8ad4c5d5ee7905 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 19 Mar 2024 03:04:04 -0700 Subject: [PATCH 2/5] Fix docstring Signed-off-by: Mattt Zmuda --- replicate/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replicate/exceptions.py b/replicate/exceptions.py index 29ec776e..6e948c89 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -49,7 +49,7 @@ def __init__( @classmethod def from_response(cls, response: httpx.Response) -> "ReplicateError": - """Create a ReplicateError from a requests.Response.""" + """Create a ReplicateError from an HTTP response.""" try: data = response.json() except ValueError: From 716e1a1bfba816a1dca798d795667f2a098f374d Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 19 Mar 2024 03:05:50 -0700 Subject: [PATCH 3/5] Improve implementation of __str__ Signed-off-by: Mattt Zmuda --- replicate/exceptions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/replicate/exceptions.py b/replicate/exceptions.py index 6e948c89..ed6d198f 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -77,4 +77,6 @@ def to_dict(self) -> dict: } def __str__(self) -> str: - return f"ReplicateError: {self.to_dict()}" + return "ReplicateError Details:\n" + "\n".join( + [f"{key}: {value}" for key, value in self.to_dict().items()] + ) From a3806f05cab14da46ec289b9b5ee62d7ea7735c7 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 19 Mar 2024 03:06:01 -0700 Subject: [PATCH 4/5] Implement __repr__ Signed-off-by: Mattt Zmuda --- replicate/exceptions.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/replicate/exceptions.py b/replicate/exceptions.py index ed6d198f..4ac839c0 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -80,3 +80,16 @@ def __str__(self) -> str: return "ReplicateError Details:\n" + "\n".join( [f"{key}: {value}" for key, value in self.to_dict().items()] ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + params = ", ".join( + [ + f"type={repr(self.type)}", + f"title={repr(self.title)}", + f"status={repr(self.status)}", + f"detail={repr(self.detail)}", + f"instance={repr(self.instance)}", + ] + ) + return f"{class_name}({params})" From 5a4172ac463cf52ee2e3f5387605f7bad7c6c467 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 19 Mar 2024 03:24:41 -0700 Subject: [PATCH 5/5] Fix expected test outputs for error strings Signed-off-by: Mattt Zmuda --- tests/test_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index a95711a9..163b185e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -56,8 +56,8 @@ async def test_client_error_handling(): client = replicate.Client(transport=httpx.MockTransport(router.handler)) with pytest.raises(ReplicateError) as exc_info: client._request("GET", "/") - assert "'status': 400" in str(exc_info.value) - assert "'detail': 'Client error occurred'" in str(exc_info.value) + assert "status: 400" in str(exc_info.value) + assert "detail: Client error occurred" in str(exc_info.value) @pytest.mark.asyncio @@ -83,5 +83,5 @@ async def test_server_error_handling(): client = replicate.Client(transport=httpx.MockTransport(router.handler)) with pytest.raises(ReplicateError) as exc_info: client._request("GET", "/") - assert "'status': 500" in str(exc_info.value) - assert "'detail': 'Server error occurred'" in str(exc_info.value) + assert "status: 500" in str(exc_info.value) + assert "detail: Server error occurred" in str(exc_info.value)