Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion workflowai/core/client/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ async def reply(
user_message: Optional[str] = None,
tool_results: Optional[Iterable[ToolCallResult]] = None,
current_iteration: int = 0,
max_retries: int = 2,
**kwargs: Unpack[RunParams[AgentOutput]],
):
"""Reply to a run to provide additional information or context.
Expand All @@ -489,7 +490,18 @@ async def reply(
prepared_run = await self._prepare_reply(run_id, user_message, tool_results, stream=False, **kwargs)
validator, new_kwargs = self._sanitize_validator(kwargs, self._default_validator)

res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True)
async def _with_retries():
err: Optional[WorkflowAIError] = None
for _ in range(max_retries):
try:
return await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True)
except WorkflowAIError as e: # noqa: PERF203
if e.code != "object_not_found":
raise e
err = e
raise err or RuntimeError("This should never raise")

res = await _with_retries()
return await self._build_run(
res,
prepared_run.schema_id,
Expand Down
72 changes: 72 additions & 0 deletions workflowai/core/client/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,3 +1078,75 @@ async def test_stream_validation_final_error(

assert e.value.partial_output == {"message": 1}
assert e.value.run_id == "1"


class TestReply:
async def test_reply_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
httpx_mock.add_response(
url="http://localhost:8000/v1/_/agents/123/runs/1/reply",
json=fixtures_json("task_run.json"),
)
reply = await agent.reply(run_id="1", user_message="test message")
assert reply.output.message == "Austin"

assert len(httpx_mock.get_requests()) == 1

async def test_reply_first_404(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
"""Check that we retry once if the run is not found"""

httpx_mock.add_response(
url="http://localhost:8000/v1/_/agents/123/runs/1/reply",
status_code=404,
json={
"error": {
"code": "object_not_found",
},
},
)

httpx_mock.add_response(
url="http://localhost:8000/v1/_/agents/123/runs/1/reply",
json=fixtures_json("task_run.json"),
)

reply = await agent.reply(run_id="1", user_message="test message")
assert reply.output.message == "Austin"

assert len(httpx_mock.get_requests()) == 2

async def test_reply_not_not_found_error(
self,
httpx_mock: HTTPXMock,
agent: Agent[HelloTaskInput, HelloTaskOutput],
):
"""Check that we raise the error if it's not a 404"""
httpx_mock.add_response(
url="http://localhost:8000/v1/_/agents/123/runs/1/reply",
status_code=400,
json={
"error": {
"code": "whatever",
},
},
)
with pytest.raises(WorkflowAIError) as e:
await agent.reply(run_id="1", user_message="test message")
assert e.value.code == "whatever"
assert len(httpx_mock.get_requests()) == 1

async def test_reply_multiple_retries(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]):
"""Check that we retry once if the run is not found"""
httpx_mock.add_response(
url="http://localhost:8000/v1/_/agents/123/runs/1/reply",
status_code=404,
json={
"error": {
"code": "object_not_found",
},
},
is_reusable=True,
)
with pytest.raises(WorkflowAIError) as e:
await agent.reply(run_id="1", user_message="test message")
assert e.value.code == "object_not_found"
assert len(httpx_mock.get_requests()) == 2
22 changes: 13 additions & 9 deletions workflowai/core/utils/_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Mapping, Sequence
from typing import Any, TypeVar, get_args, get_origin

Expand Down Expand Up @@ -25,9 +26,11 @@ def _copy_field_info(field_info: FieldInfo, **overrides: Any):
certain values.
"""

_excluded = {"annotation", "required"}

kwargs = overrides
for k, v in field_info.__repr_args__():
if k in kwargs or not k:
if k in kwargs or not k or k in _excluded:
continue
kwargs[k] = v

Expand Down Expand Up @@ -79,7 +82,6 @@ def partial_model(base: type[BM]) -> type[BM]:
overrides: dict[str, Any] = {}
try:
annotation = _optional_annotation(field.annotation)
overrides["annotation"] = annotation
overrides["default"] = _default_value_from_annotation(annotation)
except Exception: # noqa: BLE001
logger.debug("Failed to make annotation optional", exc_info=True)
Expand All @@ -95,10 +97,12 @@ def custom_eq(o1: BM, o2: Any):
return False
return o1.model_dump() == o2.model_dump()

return create_model( # pyright: ignore [reportCallIssue, reportUnknownVariableType]
f"Partial{base.__name__}",
__base__=base,
__eq__=custom_eq,
__hash__=base.__hash__,
**default_fields, # pyright: ignore [reportArgumentType]
)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning, message="fields may not start with an underscore")
return create_model( # pyright: ignore [reportCallIssue, reportUnknownVariableType]
f"Partial{base.__name__}",
__base__=base,
__eq__=custom_eq,
__hash__=base.__hash__,
**default_fields, # pyright: ignore [reportArgumentType]
)
4 changes: 3 additions & 1 deletion workflowai/core/utils/_pydantic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class TestPartialModel:
def test_partial_model_equals(self):
def test_partial_model_equals(self, recwarn: pytest.WarningsRecorder):
class SimpleModel(BaseModel):
name: str

Expand All @@ -16,6 +16,8 @@ class SimpleModel(BaseModel):

assert SimpleModel(name="John") == partial.model_validate({"name": "John"})

assert len(recwarn.list) == 0

def test_simple_model(self):
class SimpleModel(BaseModel):
name1: str
Expand Down