diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index e0c29e4..c480d4d 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -472,6 +472,7 @@ async def reply( user_message: Optional[str] = None, tool_results: Optional[Iterable[ToolCallResult]] = None, current_iteration: int = 0, + max_retries: int = 2, **kwargs: Unpack[RunParams[AgentOutput]], ): """Reply to a run to provide additional information or context. @@ -489,7 +490,18 @@ async def reply( prepared_run = await self._prepare_reply(run_id, user_message, tool_results, stream=False, **kwargs) validator, new_kwargs = self._sanitize_validator(kwargs, self._default_validator) - res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) + async def _with_retries(): + err: Optional[WorkflowAIError] = None + for _ in range(max_retries): + try: + return await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) + except WorkflowAIError as e: # noqa: PERF203 + if e.code != "object_not_found": + raise e + err = e + raise err or RuntimeError("This should never raise") + + res = await _with_retries() return await self._build_run( res, prepared_run.schema_id, diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 8779b47..ca18480 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -1078,3 +1078,75 @@ async def test_stream_validation_final_error( assert e.value.partial_output == {"message": 1} assert e.value.run_id == "1" + + +class TestReply: + async def test_reply_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/runs/1/reply", + json=fixtures_json("task_run.json"), + ) + reply = await agent.reply(run_id="1", user_message="test message") + assert reply.output.message == "Austin" + + assert len(httpx_mock.get_requests()) == 1 + + async def test_reply_first_404(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): + """Check that we retry once if the run is not found""" + + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/runs/1/reply", + status_code=404, + json={ + "error": { + "code": "object_not_found", + }, + }, + ) + + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/runs/1/reply", + json=fixtures_json("task_run.json"), + ) + + reply = await agent.reply(run_id="1", user_message="test message") + assert reply.output.message == "Austin" + + assert len(httpx_mock.get_requests()) == 2 + + async def test_reply_not_not_found_error( + self, + httpx_mock: HTTPXMock, + agent: Agent[HelloTaskInput, HelloTaskOutput], + ): + """Check that we raise the error if it's not a 404""" + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/runs/1/reply", + status_code=400, + json={ + "error": { + "code": "whatever", + }, + }, + ) + with pytest.raises(WorkflowAIError) as e: + await agent.reply(run_id="1", user_message="test message") + assert e.value.code == "whatever" + assert len(httpx_mock.get_requests()) == 1 + + async def test_reply_multiple_retries(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): + """Check that we retry once if the run is not found""" + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/runs/1/reply", + status_code=404, + json={ + "error": { + "code": "object_not_found", + }, + }, + is_reusable=True, + ) + with pytest.raises(WorkflowAIError) as e: + await agent.reply(run_id="1", user_message="test message") + assert e.value.code == "object_not_found" + assert len(httpx_mock.get_requests()) == 2 diff --git a/workflowai/core/utils/_pydantic.py b/workflowai/core/utils/_pydantic.py index f8c733f..c611fc4 100644 --- a/workflowai/core/utils/_pydantic.py +++ b/workflowai/core/utils/_pydantic.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Mapping, Sequence from typing import Any, TypeVar, get_args, get_origin @@ -25,9 +26,11 @@ def _copy_field_info(field_info: FieldInfo, **overrides: Any): certain values. """ + _excluded = {"annotation", "required"} + kwargs = overrides for k, v in field_info.__repr_args__(): - if k in kwargs or not k: + if k in kwargs or not k or k in _excluded: continue kwargs[k] = v @@ -79,7 +82,6 @@ def partial_model(base: type[BM]) -> type[BM]: overrides: dict[str, Any] = {} try: annotation = _optional_annotation(field.annotation) - overrides["annotation"] = annotation overrides["default"] = _default_value_from_annotation(annotation) except Exception: # noqa: BLE001 logger.debug("Failed to make annotation optional", exc_info=True) @@ -95,10 +97,12 @@ def custom_eq(o1: BM, o2: Any): return False return o1.model_dump() == o2.model_dump() - return create_model( # pyright: ignore [reportCallIssue, reportUnknownVariableType] - f"Partial{base.__name__}", - __base__=base, - __eq__=custom_eq, - __hash__=base.__hash__, - **default_fields, # pyright: ignore [reportArgumentType] - ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning, message="fields may not start with an underscore") + return create_model( # pyright: ignore [reportCallIssue, reportUnknownVariableType] + f"Partial{base.__name__}", + __base__=base, + __eq__=custom_eq, + __hash__=base.__hash__, + **default_fields, # pyright: ignore [reportArgumentType] + ) diff --git a/workflowai/core/utils/_pydantic_test.py b/workflowai/core/utils/_pydantic_test.py index 858a607..db8c9e9 100644 --- a/workflowai/core/utils/_pydantic_test.py +++ b/workflowai/core/utils/_pydantic_test.py @@ -7,7 +7,7 @@ class TestPartialModel: - def test_partial_model_equals(self): + def test_partial_model_equals(self, recwarn: pytest.WarningsRecorder): class SimpleModel(BaseModel): name: str @@ -16,6 +16,8 @@ class SimpleModel(BaseModel): assert SimpleModel(name="John") == partial.model_validate({"name": "John"}) + assert len(recwarn.list) == 0 + def test_simple_model(self): class SimpleModel(BaseModel): name1: str