From 27ee41e521928a613ebc54a8162e9c6cb7c57ef5 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Thu, 10 Apr 2025 11:33:34 +0200 Subject: [PATCH] feat: sanitize workflowai error message --- .github/workflows/quality.yml | 2 +- .pre-commit-config.yaml | 2 +- Makefile | 2 +- examples/01_basic_agent.py | 8 ++- examples/18_flight_info_extraction.py | 4 ++ pytest.ini | 6 +-- workflowai/core/client/_api.py | 24 +-------- workflowai/core/client/_api_test.py | 77 --------------------------- workflowai/core/domain/errors.py | 66 ++++++++++++----------- workflowai/core/domain/errors_test.py | 77 +++++++++++++++++++++++---- 10 files changed, 120 insertions(+), 148 deletions(-) diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index d9d003d..e18c8de 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -47,4 +47,4 @@ jobs: poetry install --all-extras - name: Run tests - run: poetry run pytest --ignore=tests/e2e + run: poetry run pytest --ignore=tests/e2e --ignore-glob="examples/*.py" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fd7d244..a3908e5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,6 +22,6 @@ repos: types: [file, python] - id: pytest name: testing (pytest) - entry: pytest . --ignore tests/e2e + entry: make test language: system pass_filenames: false diff --git a/Makefile b/Makefile index b8295b2..9c2bd9f 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,7 @@ lint: .PHONY: test test: - pytest --ignore=tests/e2e + pytest --ignore=tests/e2e --ignore-glob="examples/*.py" .PHONY: lock lock: diff --git a/examples/01_basic_agent.py b/examples/01_basic_agent.py index 8d9f266..e83bddc 100644 --- a/examples/01_basic_agent.py +++ b/examples/01_basic_agent.py @@ -80,8 +80,12 @@ async def main(): # Example 1: Basic usage with Paris print("\nExample 1: Basic usage with Paris") print("-" * 50) - run = await get_capital_info.run(CityInput(city="Paris")) - print(run) + try: + run = await get_capital_info.run(CityInput(city="Paris")) + print(run) + except workflowai.WorkflowAIError as e: + print(e) + return # Example 2: Using Tokyo print("\nExample 2: Using Tokyo") diff --git a/examples/18_flight_info_extraction.py b/examples/18_flight_info_extraction.py index 58a7614..75eb768 100644 --- a/examples/18_flight_info_extraction.py +++ b/examples/18_flight_info_extraction.py @@ -19,13 +19,16 @@ class EmailInput(BaseModel): """Raw email content containing flight booking details. This could be a confirmation email, itinerary update, or e-ticket from any airline.""" + email_content: str class FlightInfo(BaseModel): """Model for extracted flight information.""" + class Status(str, Enum): """Possible statuses for a flight booking.""" + CONFIRMED = "Confirmed" PENDING = "Pending" CANCELLED = "Cancelled" @@ -41,6 +44,7 @@ class Status(str, Enum): arrival: datetime status: Status + @workflowai.agent( id="flight-info-extractor", model=Model.GEMINI_2_0_FLASH_LATEST, diff --git a/pytest.ini b/pytest.ini index 1f94d62..f756c03 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,7 @@ [pytest] -python_files = *_test.py -python_functions = test_* -testpaths = tests workflowai +python_files = *_test.py examples/*.py +python_functions = test_* main +testpaths = tests workflowai examples filterwarnings = ignore::pydantic.warnings.PydanticDeprecatedSince20 asyncio_mode = auto diff --git a/workflowai/core/client/_api.py b/workflowai/core/client/_api.py index 8228ab5..1c793cd 100644 --- a/workflowai/core/client/_api.py +++ b/workflowai/core/client/_api.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, TypeAdapter, ValidationError from workflowai.core._logger import logger -from workflowai.core.domain.errors import BaseError, ErrorResponse, WorkflowAIError +from workflowai.core.domain.errors import BaseError, WorkflowAIError # A type for return values _R = TypeVar("_R") @@ -103,26 +103,6 @@ async def delete(self, path: str) -> None: response = await client.delete(path) await self.raise_for_status(response) - def _extract_error( - self, - response: httpx.Response, - data: Union[bytes, str], - exception: Optional[Exception] = None, - ) -> WorkflowAIError: - try: - res = ErrorResponse.model_validate_json(data) - return WorkflowAIError(error=res.error, run_id=res.id, response=response, partial_output=res.task_output) - except ValidationError: - raise WorkflowAIError( - error=BaseError( - message="Unknown error" if exception is None else str(exception), - details={ - "raw": str(data), - }, - ), - response=response, - ) from exception - async def _wrap_sse(self, raw: AsyncIterator[bytes], termination_chars: bytes = b"\n\n"): data = b"" in_data = False @@ -181,7 +161,7 @@ async def stream( try: yield returns.model_validate_json(chunk) except ValidationError as e: - raise self._extract_error(response, chunk, e) from None + raise WorkflowAIError.from_response(response, chunk) from e async def raise_for_status(self, response: httpx.Response): if response.status_code < 200 or response.status_code >= 300: diff --git a/workflowai/core/client/_api_test.py b/workflowai/core/client/_api_test.py index 8c4ebd4..ded633e 100644 --- a/workflowai/core/client/_api_test.py +++ b/workflowai/core/client/_api_test.py @@ -9,83 +9,6 @@ from workflowai.core.domain.errors import WorkflowAIError -class TestAPIClientExtractError: - def test_extract_error(self): - client = APIClient(url="test_url", api_key="test_api_key") - - # Test valid JSON error response - response = httpx.Response( - status_code=400, - json={ - "error": { - "message": "Test error message", - "details": {"key": "value"}, - }, - "id": "test_task_123", - }, - ) - - error = client._extract_error(response, response.content) # pyright:ignore[reportPrivateUsage] - assert isinstance(error, WorkflowAIError) - assert error.error.message == "Test error message" - assert error.error.details == {"key": "value"} - assert error.run_id == "test_task_123" - assert error.response == response - - def test_extract_partial_output(self): - client = APIClient(url="test_url", api_key="test_api_key") - - # Test valid JSON error response - response = httpx.Response( - status_code=400, - json={ - "error": { - "message": "Test error message", - "details": {"key": "value"}, - }, - "id": "test_task_123", - "task_output": {"key": "value"}, - }, - ) - - error = client._extract_error(response, response.content) # pyright:ignore[reportPrivateUsage] - assert isinstance(error, WorkflowAIError) - assert error.error.message == "Test error message" - assert error.error.details == {"key": "value"} - assert error.run_id == "test_task_123" - assert error.partial_output == {"key": "value"} - assert error.response == response - - def test_extract_error_invalid_json(self): - client = APIClient(url="test_url", api_key="test_api_key") - - # Test invalid JSON response - invalid_data = b"Invalid JSON data" - response = httpx.Response(status_code=400, content=invalid_data) - - with pytest.raises(WorkflowAIError) as e: - client._extract_error(response, invalid_data) # pyright:ignore[reportPrivateUsage] - assert isinstance(e.value, WorkflowAIError) - assert e.value.error.message == "Unknown error" - assert e.value.error.details == {"raw": "b'Invalid JSON data'"} - assert e.value.response == response - - def test_extract_error_with_custom_error(self): - client = APIClient(url="test_url", api_key="test_api_key") - - # Test with provided exception - invalid_data = "{'detail': 'Not Found'}" - response = httpx.Response(status_code=404, content=invalid_data) - exception = ValueError("Custom error") - - with pytest.raises(WorkflowAIError) as e: - client._extract_error(response, invalid_data, exception) # pyright:ignore[reportPrivateUsage] - assert isinstance(e.value, WorkflowAIError) - assert e.value.error.message == "Custom error" - assert e.value.error.details == {"raw": "{'detail': 'Not Found'}"} - assert e.value.response == response - - @pytest.fixture def client() -> APIClient: return APIClient(url="https://blabla.com", api_key="test_api_key") diff --git a/workflowai/core/domain/errors.py b/workflowai/core/domain/errors.py index 6f6e2a2..bbf37e7 100644 --- a/workflowai/core/domain/errors.py +++ b/workflowai/core/domain/errors.py @@ -1,10 +1,10 @@ from email.utils import parsedate_to_datetime -from json import JSONDecodeError from time import time from typing import Any, Literal, Optional, Union from httpx import Response -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError +from typing_extensions import override from workflowai.core.domain import tool_call @@ -79,7 +79,7 @@ class BaseError(BaseModel): details: Optional[dict[str, Any]] = None - message: str + message: str = "Unknown error" status_code: Optional[int] = None code: Optional[ErrorCode] = None @@ -127,41 +127,29 @@ def __str__(self): return f"WorkflowAIError : [{self.error.code}] ({self.error.status_code}): [{self.error.message}]" @classmethod - def error_cls(cls, code: str): + def error_cls(cls, status_code: int, code: Optional[str] = None): + if status_code == 401: + return InvalidAPIKeyError if code == "invalid_generation" or code == "failed_generation" or code == "agent_run_failed": return InvalidGenerationError return cls @classmethod - def from_response(cls, response: Response): + def from_response(cls, response: Response, data: Union[bytes, str, None] = None): try: - response_json = response.json() - r_error = response_json.get("error", {}) - error_message = response_json.get("detail", {}) or r_error.get("message", "Unknown Error") - details = r_error.get("details", {}) - error_code = r_error.get("code", "unknown_error") - status_code = response.status_code - run_id = response_json.get("id", None) - partial_output = response_json.get("task_output", None) - except JSONDecodeError: - error_message = "Unknown error" - details = {"raw": response.content.decode()} - error_code = "unknown_error" - status_code = response.status_code - run_id = None - partial_output = None - - return cls.error_cls(error_code)( - response=response, - error=BaseError( - message=error_message, - details=details, - status_code=status_code, - code=error_code, - ), - run_id=run_id, - partial_output=partial_output, - ) + res = ErrorResponse.model_validate_json(data or response.content) + error_cls = cls.error_cls(response.status_code, res.error.code) + return error_cls(error=res.error, run_id=res.id, response=response, partial_output=res.task_output) + except ValidationError: + return cls.error_cls(response.status_code)( + error=BaseError( + message="Unknown error", + details={ + "raw": str(data), + }, + ), + response=response, + ) @property def retry_after_delay_seconds(self) -> Optional[float]: @@ -194,3 +182,17 @@ class InvalidGenerationError(WorkflowAIError): ... class MaxTurnsReachedError(WorkflowAIError): ... + + +class InvalidAPIKeyError(WorkflowAIError): + @property + @override + def message(self) -> str: + return ( + "❌ Your API key is invalid. Please double-check your API key, " + "or create a new one at https://workflowai.com/organization/settings/api-keys " + "or from your self-hosted WorkflowAI instance." + ) + + def __str__(self) -> str: + return self.message diff --git a/workflowai/core/domain/errors_test.py b/workflowai/core/domain/errors_test.py index 243a5da..168c093 100644 --- a/workflowai/core/domain/errors_test.py +++ b/workflowai/core/domain/errors_test.py @@ -1,26 +1,73 @@ from typing import Optional from unittest.mock import Mock +import httpx import pytest from freezegun import freeze_time from workflowai.core.domain.errors import ( BaseError, + InvalidAPIKeyError, WorkflowAIError, _retry_after_to_delay_seconds, # pyright: ignore [reportPrivateUsage] ) -def test_workflow_ai_error_404(): - response = Mock() - response.status_code = 404 - response.json = Mock() - response.json.return_value = {"error": {"message": "None", "status_code": 404, "code": "object_not_found"}} +class TestErrorFromResponse: + def test_extract_error(self): + # Test valid JSON error response + response = httpx.Response( + status_code=400, + json={ + "error": { + "message": "Test error message", + "details": {"key": "value"}, + "code": "object_not_found", + }, + "id": "test_task_123", + }, + ) - error = WorkflowAIError.from_response(response) - assert error.error.message == "None" - assert error.error.status_code == 404 - assert error.error.code == "object_not_found" + error = WorkflowAIError.from_response(response, response.content) + assert isinstance(error, WorkflowAIError) + assert error.error.message == "Test error message" + assert error.error.details == {"key": "value"} + assert error.run_id == "test_task_123" + assert error.response == response + assert error.code == "object_not_found" + + def test_extract_partial_output(self): + # Test valid JSON error response + response = httpx.Response( + status_code=400, + json={ + "error": { + "message": "Test error message", + "details": {"key": "value"}, + }, + "id": "test_task_123", + "task_output": {"key": "value"}, + }, + ) + + error = WorkflowAIError.from_response(response, response.content) + assert isinstance(error, WorkflowAIError) + assert error.error.message == "Test error message" + assert error.error.details == {"key": "value"} + assert error.run_id == "test_task_123" + assert error.partial_output == {"key": "value"} + assert error.response == response + + def test_extract_error_invalid_json(self): + # Test invalid JSON response + invalid_data = b"Invalid JSON data" + response = httpx.Response(status_code=400, content=invalid_data) + + error = WorkflowAIError.from_response(response, invalid_data) + assert isinstance(error, WorkflowAIError) + assert error.error.message == "Unknown error" + assert error.error.details == {"raw": "b'Invalid JSON data'"} + assert error.response == response @freeze_time("2024-01-01T00:00:00Z") @@ -83,3 +130,15 @@ def test_workflow_ai_error_details(): ), ) assert error.details == {"test": "test"} + + +def test_invalid_api_key_error(): + error = InvalidAPIKeyError( + response=Mock(), + error=BaseError( + message="test", + status_code=404, + code="object_not_found", + ), + ) + assert str(error).startswith("❌ Your API key is invalid")