Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.5.1"
version = "0.5.2"
description = ""
authors = ["Guillaume Aquilina <guillaume@workflowai.com>"]
readme = "README.md"
Expand Down
18 changes: 14 additions & 4 deletions workflowai/core/client/_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Literal, Optional, TypeVar, Union, overload

import httpx
Expand All @@ -19,17 +20,26 @@ def __init__(self, endpoint: str, api_key: str, source_headers: Optional[dict[st
self.api_key = api_key
self.source_headers = source_headers or {}

def _client(self) -> httpx.AsyncClient:
@asynccontextmanager
async def _client(self):
source_headers = self.source_headers or {}
client = httpx.AsyncClient(
async with httpx.AsyncClient(
base_url=self.endpoint,
headers={
"Authorization": f"Bearer {self.api_key}",
**source_headers,
},
timeout=120.0,
)
return client
) as client:
try:
yield client
except (httpx.ReadError, httpx.ConnectError) as e:
raise WorkflowAIError(
response=None,
error=BaseError(message="Could not read response", code="connection_error"),
# We can retry after 10ms
retry_after_delay_seconds=0.010,
) from e

async def get(self, path: str, returns: type[_R], query: Union[dict[str, Any], None] = None) -> _R:
async with self._client() as client:
Expand Down
35 changes: 35 additions & 0 deletions workflowai/core/client/_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,38 @@ async def test_stream_with_multiple_chunks(
httpx_mock.add_response(stream=IteratorStream(streamed_chunks))
chunks = await stream_fn()
assert chunks == [TestOutputModel(a="test"), TestOutputModel(a="test2")]


class TestReadAndConnectError:
@pytest.mark.parametrize("exception", [httpx.ReadError("arg"), httpx.ConnectError("arg")])
async def test_get(self, httpx_mock: HTTPXMock, client: APIClient, exception: Exception):
httpx_mock.add_exception(exception)

with pytest.raises(WorkflowAIError) as e:
await client.get(path="test_path", returns=TestOutputModel)

assert e.value.error.code == "connection_error"

@pytest.mark.parametrize("exception", [httpx.ReadError("arg"), httpx.ConnectError("arg")])
async def test_post(self, httpx_mock: HTTPXMock, client: APIClient, exception: Exception):
httpx_mock.add_exception(exception)

with pytest.raises(WorkflowAIError) as e:
await client.post(path="test_path", data=TestInputModel(), returns=TestOutputModel)

assert e.value.error.code == "connection_error"

@pytest.mark.parametrize("exception", [httpx.ReadError("arg"), httpx.ConnectError("arg")])
async def test_stream(self, httpx_mock: HTTPXMock, client: APIClient, exception: Exception):
httpx_mock.add_exception(exception)

with pytest.raises(WorkflowAIError) as e: # noqa: PT012
async for _ in client.stream(
method="GET",
path="test_path",
data=TestInputModel(),
returns=TestOutputModel,
):
pass

assert e.value.error.code == "connection_error"
25 changes: 5 additions & 20 deletions workflowai/core/client/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
# the delimiter is not withing a quoted string
import asyncio
import re
from email.utils import parsedate_to_datetime
from json import JSONDecodeError
from time import time
from typing import Any, Optional

from workflowai.core.client._types import OutputValidator
from workflowai.core.domain.errors import BaseError, WorkflowAIError
Expand All @@ -24,22 +22,6 @@ def split_chunks(chunk: bytes):
yield chunk_str[start:]


def retry_after_to_delay_seconds(retry_after: Any) -> Optional[float]:
if retry_after is None:
return None

try:
return float(retry_after)
except ValueError:
pass
try:
retry_after_date = parsedate_to_datetime(retry_after)
current_time = time()
return retry_after_date.timestamp() - current_time
except (TypeError, ValueError, OverflowError):
return None


# Returns two functions:
# - _should_retry: returns True if we should retry
# - _wait_for_exception: waits after an exception only if we should retry, otherwise raises
Expand All @@ -60,13 +42,16 @@ def _should_retry():
return retry_count < max_retry_count and _leftover_delay() >= 0

async def _wait_for_exception(e: WorkflowAIError):
if not e.response:
retry_after = e.retry_after_delay_seconds
if retry_after is None:
raise e

nonlocal retry_count
retry_after = retry_after_to_delay_seconds(e.response.headers.get("Retry-After"))
leftover_delay = _leftover_delay()
if not retry_after or leftover_delay < 0 or retry_count >= max_retry_count:
if not e.response:
raise e

# Convert error to WorkflowAIError
try:
response_json = e.response.json()
Expand Down
22 changes: 3 additions & 19 deletions workflowai/core/client/_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import Optional
from unittest.mock import Mock

import pytest
from freezegun import freeze_time
from httpx import HTTPStatusError

from workflowai.core.client._utils import build_retryable_wait, retry_after_to_delay_seconds, split_chunks
from workflowai.core.domain.errors import WorkflowAIError
from workflowai.core.client._utils import build_retryable_wait, split_chunks
from workflowai.core.domain.errors import BaseError, WorkflowAIError


@pytest.mark.parametrize(
Expand All @@ -23,25 +20,12 @@ def test_split_chunks(chunk: bytes, expected: list[bytes]):
assert list(split_chunks(chunk)) == expected


@freeze_time("2024-01-01T00:00:00Z")
@pytest.mark.parametrize(
("retry_after", "expected"),
[
(None, None),
("10", 10),
("Wed, 01 Jan 2024 00:00:10 UTC", 10),
],
)
def test_retry_after_to_delay_seconds(retry_after: Optional[str], expected: Optional[float]):
assert retry_after_to_delay_seconds(retry_after) == expected


class TestBuildRetryableWait:
@pytest.fixture
def request_error(self):
response = Mock()
response.headers = {"Retry-After": "0.01"}
return HTTPStatusError(message="", request=Mock(), response=response)
return WorkflowAIError(response=response, error=BaseError(message=""))

async def test_should_retry_count(self, request_error: WorkflowAIError):
should_retry, wait_for_exception = build_retryable_wait(60, 1)
Expand Down
22 changes: 22 additions & 0 deletions workflowai/core/client/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, AsyncIterator
from unittest.mock import AsyncMock, patch

import httpx
import pytest
from pytest_httpx import HTTPXMock, IteratorStream

Expand All @@ -13,6 +14,7 @@
WorkflowAIClient,
_compute_default_version_reference, # pyright: ignore [reportPrivateUsage]
)
from workflowai.core.domain.errors import WorkflowAIError
from workflowai.core.domain.run import Run


Expand Down Expand Up @@ -178,6 +180,26 @@ async def test_run_retries_on_too_many_requests(self, httpx_mock: HTTPXMock, cli
assert reqs[0].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run"
assert reqs[1].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run"

async def test_run_retries_on_connection_error(self, httpx_mock: HTTPXMock, client: Client):
task = HelloTask(id="123", schema_id=1)

httpx_mock.add_exception(httpx.ConnectError("arg"))
httpx_mock.add_response(json=fixtures_json("task_run.json"))

task_run = await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5)
assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc"

async def test_max_retries(self, httpx_mock: HTTPXMock, client: Client):
task = HelloTask(id="123", schema_id=1)

httpx_mock.add_exception(httpx.ConnectError("arg"), is_reusable=True)

with pytest.raises(WorkflowAIError):
await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5)

reqs = httpx_mock.get_requests()
assert len(reqs) == 5


class TestTask:
@pytest.fixture
Expand Down
38 changes: 37 additions & 1 deletion workflowai/core/domain/errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from email.utils import parsedate_to_datetime
from json import JSONDecodeError
from time import time
from typing import Any, Literal, Optional, Union

from httpx import Response
Expand Down Expand Up @@ -53,6 +55,7 @@
# The request was invalid
"bad_request",
"invalid_file",
"connection_error",
],
str, # Using as a fallback to avoid validation error if an error code is added to the API
]
Expand All @@ -70,11 +73,34 @@ class ErrorResponse(BaseModel):
task_run_id: Optional[str] = None


def _retry_after_to_delay_seconds(retry_after: Any) -> Optional[float]:
if retry_after is None:
return None

try:
return float(retry_after)
except ValueError:
pass
try:
retry_after_date = parsedate_to_datetime(retry_after)
current_time = time()
return retry_after_date.timestamp() - current_time
except (TypeError, ValueError, OverflowError):
return None


class WorkflowAIError(Exception):
def __init__(self, response: Optional[Response], error: BaseError, task_run_id: Optional[str] = None):
def __init__(
self,
response: Optional[Response],
error: BaseError,
task_run_id: Optional[str] = None,
retry_after_delay_seconds: Optional[float] = None,
):
self.error = error
self.task_run_id = task_run_id
self.response = response
self._retry_after_delay_seconds = retry_after_delay_seconds

def __str__(self):
return f"WorkflowAIError : [{self.error.code}] ({self.error.status_code}): [{self.error.message}]"
Expand Down Expand Up @@ -106,3 +132,13 @@ def from_response(cls, response: Response):
),
task_run_id=task_run_id,
)

@property
def retry_after_delay_seconds(self) -> Optional[float]:
if self._retry_after_delay_seconds:
return self._retry_after_delay_seconds

if self.response:
return _retry_after_to_delay_seconds(self.response.headers.get("Retry-After"))

return None
23 changes: 21 additions & 2 deletions workflowai/core/domain/errors_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@

from typing import Optional
from unittest.mock import Mock

from workflowai.core.domain.errors import WorkflowAIError
import pytest
from freezegun import freeze_time

from workflowai.core.domain.errors import (
WorkflowAIError,
_retry_after_to_delay_seconds, # pyright: ignore [reportPrivateUsage]
)


def test_workflow_ai_error_404():
Expand All @@ -14,3 +20,16 @@ def test_workflow_ai_error_404():
assert error.error.message == "None"
assert error.error.status_code == 404
assert error.error.code == "object_not_found"


@freeze_time("2024-01-01T00:00:00Z")
@pytest.mark.parametrize(
("retry_after", "expected"),
[
(None, None),
("10", 10),
("Wed, 01 Jan 2024 00:00:10 UTC", 10),
],
)
def test_retry_after_to_delay_seconds(retry_after: Optional[str], expected: Optional[float]):
assert _retry_after_to_delay_seconds(retry_after) == expected
Loading