From 4b57e3b17f8b395c24a03d49b1a0351983a942e1 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Mon, 9 Dec 2024 15:05:24 -0600 Subject: [PATCH] fix: retry on connect error --- pyproject.toml | 2 +- workflowai/core/client/_api.py | 18 ++++++++++--- workflowai/core/client/_api_test.py | 35 ++++++++++++++++++++++++ workflowai/core/client/_utils.py | 25 ++++-------------- workflowai/core/client/_utils_test.py | 22 +++------------- workflowai/core/client/client_test.py | 22 ++++++++++++++++ workflowai/core/domain/errors.py | 38 ++++++++++++++++++++++++++- workflowai/core/domain/errors_test.py | 23 ++++++++++++++-- 8 files changed, 138 insertions(+), 47 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e7fce42..a614722 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.5.1" +version = "0.5.2" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/workflowai/core/client/_api.py b/workflowai/core/client/_api.py index edb0789..c831f47 100644 --- a/workflowai/core/client/_api.py +++ b/workflowai/core/client/_api.py @@ -1,4 +1,5 @@ import logging +from contextlib import asynccontextmanager from typing import Any, AsyncIterator, Literal, Optional, TypeVar, Union, overload import httpx @@ -19,17 +20,26 @@ def __init__(self, endpoint: str, api_key: str, source_headers: Optional[dict[st self.api_key = api_key self.source_headers = source_headers or {} - def _client(self) -> httpx.AsyncClient: + @asynccontextmanager + async def _client(self): source_headers = self.source_headers or {} - client = httpx.AsyncClient( + async with httpx.AsyncClient( base_url=self.endpoint, headers={ "Authorization": f"Bearer {self.api_key}", **source_headers, }, timeout=120.0, - ) - return client + ) as client: + try: + yield client + except (httpx.ReadError, httpx.ConnectError) as e: + raise WorkflowAIError( + response=None, + error=BaseError(message="Could not read response", code="connection_error"), + # We can retry after 10ms + retry_after_delay_seconds=0.010, + ) from e async def get(self, path: str, returns: type[_R], query: Union[dict[str, Any], None] = None) -> _R: async with self._client() as client: diff --git a/workflowai/core/client/_api_test.py b/workflowai/core/client/_api_test.py index 88ded0f..d5aeaa5 100644 --- a/workflowai/core/client/_api_test.py +++ b/workflowai/core/client/_api_test.py @@ -154,3 +154,38 @@ async def test_stream_with_multiple_chunks( httpx_mock.add_response(stream=IteratorStream(streamed_chunks)) chunks = await stream_fn() assert chunks == [TestOutputModel(a="test"), TestOutputModel(a="test2")] + + +class TestReadAndConnectError: + @pytest.mark.parametrize("exception", [httpx.ReadError("arg"), httpx.ConnectError("arg")]) + async def test_get(self, httpx_mock: HTTPXMock, client: APIClient, exception: Exception): + httpx_mock.add_exception(exception) + + with pytest.raises(WorkflowAIError) as e: + await client.get(path="test_path", returns=TestOutputModel) + + assert e.value.error.code == "connection_error" + + @pytest.mark.parametrize("exception", [httpx.ReadError("arg"), httpx.ConnectError("arg")]) + async def test_post(self, httpx_mock: HTTPXMock, client: APIClient, exception: Exception): + httpx_mock.add_exception(exception) + + with pytest.raises(WorkflowAIError) as e: + await client.post(path="test_path", data=TestInputModel(), returns=TestOutputModel) + + assert e.value.error.code == "connection_error" + + @pytest.mark.parametrize("exception", [httpx.ReadError("arg"), httpx.ConnectError("arg")]) + async def test_stream(self, httpx_mock: HTTPXMock, client: APIClient, exception: Exception): + httpx_mock.add_exception(exception) + + with pytest.raises(WorkflowAIError) as e: # noqa: PT012 + async for _ in client.stream( + method="GET", + path="test_path", + data=TestInputModel(), + returns=TestOutputModel, + ): + pass + + assert e.value.error.code == "connection_error" diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index 08a16e5..4abfb94 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -3,10 +3,8 @@ # the delimiter is not withing a quoted string import asyncio import re -from email.utils import parsedate_to_datetime from json import JSONDecodeError from time import time -from typing import Any, Optional from workflowai.core.client._types import OutputValidator from workflowai.core.domain.errors import BaseError, WorkflowAIError @@ -24,22 +22,6 @@ def split_chunks(chunk: bytes): yield chunk_str[start:] -def retry_after_to_delay_seconds(retry_after: Any) -> Optional[float]: - if retry_after is None: - return None - - try: - return float(retry_after) - except ValueError: - pass - try: - retry_after_date = parsedate_to_datetime(retry_after) - current_time = time() - return retry_after_date.timestamp() - current_time - except (TypeError, ValueError, OverflowError): - return None - - # Returns two functions: # - _should_retry: returns True if we should retry # - _wait_for_exception: waits after an exception only if we should retry, otherwise raises @@ -60,13 +42,16 @@ def _should_retry(): return retry_count < max_retry_count and _leftover_delay() >= 0 async def _wait_for_exception(e: WorkflowAIError): - if not e.response: + retry_after = e.retry_after_delay_seconds + if retry_after is None: raise e nonlocal retry_count - retry_after = retry_after_to_delay_seconds(e.response.headers.get("Retry-After")) leftover_delay = _leftover_delay() if not retry_after or leftover_delay < 0 or retry_count >= max_retry_count: + if not e.response: + raise e + # Convert error to WorkflowAIError try: response_json = e.response.json() diff --git a/workflowai/core/client/_utils_test.py b/workflowai/core/client/_utils_test.py index e399971..91d565d 100644 --- a/workflowai/core/client/_utils_test.py +++ b/workflowai/core/client/_utils_test.py @@ -1,12 +1,9 @@ -from typing import Optional from unittest.mock import Mock import pytest -from freezegun import freeze_time -from httpx import HTTPStatusError -from workflowai.core.client._utils import build_retryable_wait, retry_after_to_delay_seconds, split_chunks -from workflowai.core.domain.errors import WorkflowAIError +from workflowai.core.client._utils import build_retryable_wait, split_chunks +from workflowai.core.domain.errors import BaseError, WorkflowAIError @pytest.mark.parametrize( @@ -23,25 +20,12 @@ def test_split_chunks(chunk: bytes, expected: list[bytes]): assert list(split_chunks(chunk)) == expected -@freeze_time("2024-01-01T00:00:00Z") -@pytest.mark.parametrize( - ("retry_after", "expected"), - [ - (None, None), - ("10", 10), - ("Wed, 01 Jan 2024 00:00:10 UTC", 10), - ], -) -def test_retry_after_to_delay_seconds(retry_after: Optional[str], expected: Optional[float]): - assert retry_after_to_delay_seconds(retry_after) == expected - - class TestBuildRetryableWait: @pytest.fixture def request_error(self): response = Mock() response.headers = {"Retry-After": "0.01"} - return HTTPStatusError(message="", request=Mock(), response=response) + return WorkflowAIError(response=response, error=BaseError(message="")) async def test_should_retry_count(self, request_error: WorkflowAIError): should_retry, wait_for_exception = build_retryable_wait(60, 1) diff --git a/workflowai/core/client/client_test.py b/workflowai/core/client/client_test.py index a0cc734..023aa36 100644 --- a/workflowai/core/client/client_test.py +++ b/workflowai/core/client/client_test.py @@ -3,6 +3,7 @@ from typing import Any, AsyncIterator from unittest.mock import AsyncMock, patch +import httpx import pytest from pytest_httpx import HTTPXMock, IteratorStream @@ -13,6 +14,7 @@ WorkflowAIClient, _compute_default_version_reference, # pyright: ignore [reportPrivateUsage] ) +from workflowai.core.domain.errors import WorkflowAIError from workflowai.core.domain.run import Run @@ -178,6 +180,26 @@ async def test_run_retries_on_too_many_requests(self, httpx_mock: HTTPXMock, cli assert reqs[0].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run" assert reqs[1].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run" + async def test_run_retries_on_connection_error(self, httpx_mock: HTTPXMock, client: Client): + task = HelloTask(id="123", schema_id=1) + + httpx_mock.add_exception(httpx.ConnectError("arg")) + httpx_mock.add_response(json=fixtures_json("task_run.json")) + + task_run = await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5) + assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" + + async def test_max_retries(self, httpx_mock: HTTPXMock, client: Client): + task = HelloTask(id="123", schema_id=1) + + httpx_mock.add_exception(httpx.ConnectError("arg"), is_reusable=True) + + with pytest.raises(WorkflowAIError): + await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5) + + reqs = httpx_mock.get_requests() + assert len(reqs) == 5 + class TestTask: @pytest.fixture diff --git a/workflowai/core/domain/errors.py b/workflowai/core/domain/errors.py index f44b728..5c58606 100644 --- a/workflowai/core/domain/errors.py +++ b/workflowai/core/domain/errors.py @@ -1,4 +1,6 @@ +from email.utils import parsedate_to_datetime from json import JSONDecodeError +from time import time from typing import Any, Literal, Optional, Union from httpx import Response @@ -53,6 +55,7 @@ # The request was invalid "bad_request", "invalid_file", + "connection_error", ], str, # Using as a fallback to avoid validation error if an error code is added to the API ] @@ -70,11 +73,34 @@ class ErrorResponse(BaseModel): task_run_id: Optional[str] = None +def _retry_after_to_delay_seconds(retry_after: Any) -> Optional[float]: + if retry_after is None: + return None + + try: + return float(retry_after) + except ValueError: + pass + try: + retry_after_date = parsedate_to_datetime(retry_after) + current_time = time() + return retry_after_date.timestamp() - current_time + except (TypeError, ValueError, OverflowError): + return None + + class WorkflowAIError(Exception): - def __init__(self, response: Optional[Response], error: BaseError, task_run_id: Optional[str] = None): + def __init__( + self, + response: Optional[Response], + error: BaseError, + task_run_id: Optional[str] = None, + retry_after_delay_seconds: Optional[float] = None, + ): self.error = error self.task_run_id = task_run_id self.response = response + self._retry_after_delay_seconds = retry_after_delay_seconds def __str__(self): return f"WorkflowAIError : [{self.error.code}] ({self.error.status_code}): [{self.error.message}]" @@ -106,3 +132,13 @@ def from_response(cls, response: Response): ), task_run_id=task_run_id, ) + + @property + def retry_after_delay_seconds(self) -> Optional[float]: + if self._retry_after_delay_seconds: + return self._retry_after_delay_seconds + + if self.response: + return _retry_after_to_delay_seconds(self.response.headers.get("Retry-After")) + + return None diff --git a/workflowai/core/domain/errors_test.py b/workflowai/core/domain/errors_test.py index e670b62..9d62d9d 100644 --- a/workflowai/core/domain/errors_test.py +++ b/workflowai/core/domain/errors_test.py @@ -1,7 +1,13 @@ - +from typing import Optional from unittest.mock import Mock -from workflowai.core.domain.errors import WorkflowAIError +import pytest +from freezegun import freeze_time + +from workflowai.core.domain.errors import ( + WorkflowAIError, + _retry_after_to_delay_seconds, # pyright: ignore [reportPrivateUsage] +) def test_workflow_ai_error_404(): @@ -14,3 +20,16 @@ def test_workflow_ai_error_404(): assert error.error.message == "None" assert error.error.status_code == 404 assert error.error.code == "object_not_found" + + +@freeze_time("2024-01-01T00:00:00Z") +@pytest.mark.parametrize( + ("retry_after", "expected"), + [ + (None, None), + ("10", 10), + ("Wed, 01 Jan 2024 00:00:10 UTC", 10), + ], +) +def test_retry_after_to_delay_seconds(retry_after: Optional[str], expected: Optional[float]): + assert _retry_after_to_delay_seconds(retry_after) == expected