diff --git a/pyproject.toml b/pyproject.toml index f41cf23..df270f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.1" +version = "0.6.2-dev1" description = "Python SDK for WorkflowAI" authors = ["Guillaume Aquilina "] readme = "README.md" @@ -63,7 +63,7 @@ unfixable = [] [tool.ruff.lint.per-file-ignores] # in bin we use rich.print "bin/*" = ["T201"] -"*_test.py" = ["S101"] +"*_test.py" = ["S101", "S106"] "conftest.py" = ["S101"] "examples/*" = ["INP001", "T201", "ERA001"] diff --git a/workflowai/__init__.py b/workflowai/__init__.py index 0c4f5a5..35568a6 100644 --- a/workflowai/__init__.py +++ b/workflowai/__init__.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Iterable -from typing import Any, Optional +from typing import Any, Literal, Optional from typing_extensions import deprecated @@ -82,3 +82,12 @@ def agent( model=model, tools=tools, ) + + +def send_feedback( + feedback_token: str, + outcome: Literal["positive", "negative"], + comment: Optional[str] = None, + user_id: Optional[str] = None, +): + return shared_client.send_feedback(feedback_token, outcome, comment, user_id) diff --git a/workflowai/core/_common_types.py b/workflowai/core/_common_types.py index ea49c85..cde0fa2 100644 --- a/workflowai/core/_common_types.py +++ b/workflowai/core/_common_types.py @@ -40,14 +40,18 @@ class VersionRunParams(TypedDict): temperature: NotRequired[Optional[float]] -class BaseRunParams(VersionRunParams): +class OtherRunParams(TypedDict): use_cache: NotRequired["CacheUsage"] - metadata: NotRequired[Optional[dict[str, Any]]] - labels: NotRequired[Optional[set[str]]] + max_retry_delay: NotRequired[float] max_retry_count: NotRequired[float] - max_tool_iterations: NotRequired[int] # 10 by default + max_turns: NotRequired[int] # 10 by default + max_turns_raises: NotRequired[bool] # True by default + + +class BaseRunParams(VersionRunParams, OtherRunParams): + metadata: NotRequired[Optional[dict[str, Any]]] class RunParams(BaseRunParams, Generic[AgentOutput]): diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index 3930669..cc25906 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -15,6 +15,7 @@ from pydantic import BaseModel, ValidationError from typing_extensions import Unpack +from workflowai.core._common_types import OtherRunParams from workflowai.core.client._api import APIClient from workflowai.core.client._models import RunResponse from workflowai.core.client._types import ( @@ -131,7 +132,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp milliseconds. Defaults to 60000. max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1. - max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. + max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10. validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output. @@ -194,7 +195,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp milliseconds. Defaults to 60000. max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1. - max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. + max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10. validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output. @@ -234,7 +235,7 @@ def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): milliseconds. Defaults to 60000. max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1. - max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. + max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10. validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output. @@ -276,7 +277,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp milliseconds. Defaults to 60000. max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1. - max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. + max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10. validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output. @@ -318,6 +319,7 @@ def wrap_run_template( model: Optional[ModelOrStr], fn: RunTemplate[AgentInput, AgentOutput], tools: Optional[Iterable[Callable[..., Any]]] = None, + run_params: Optional[OtherRunParams] = None, ) -> Union[ _RunnableAgent[AgentInput, AgentOutput], _RunnableOutputOnlyAgent[AgentInput, AgentOutput], @@ -344,6 +346,7 @@ def wrap_run_template( schema_id=schema_id, version=version, tools=tools, + **(run_params or {}), ) @@ -358,13 +361,14 @@ def agent_wrapper( version: Optional[VersionReference] = None, model: Optional[ModelOrStr] = None, tools: Optional[Iterable[Callable[..., Any]]] = None, + **kwargs: Unpack[OtherRunParams], ) -> AgentDecorator: def wrap(fn: RunTemplate[AgentInput, AgentOutput]): tid = agent_id or agent_id_from_fn_name(fn) # TODO[types]: Not sure why a cast is needed here agent = cast( FinalRunTemplate[AgentInput, AgentOutput], - wrap_run_template(client, tid, schema_id, version, model, fn, tools), + wrap_run_template(client, tid, schema_id, version, model, fn, tools, kwargs), ) agent.__doc__ = """A class representing an AI agent that can process inputs and generate outputs. diff --git a/workflowai/core/client/_fn_utils_test.py b/workflowai/core/client/_fn_utils_test.py index f04cfab..c776dac 100644 --- a/workflowai/core/client/_fn_utils_test.py +++ b/workflowai/core/client/_fn_utils_test.py @@ -19,7 +19,7 @@ get_generic_args, is_async_iterator, ) -from workflowai.core.client._models import RunResponse +from workflowai.core.client._models import RunRequest, RunResponse from workflowai.core.domain.run import Run @@ -80,6 +80,29 @@ async def test_fn_run(self, mock_api_client: Mock): assert run.id == "1" assert run.output == HelloTaskOutput(message="Hello, World!") + async def test_fn_run_with_default_cache(self, mock_api_client: Mock): + wrapped = agent_wrapper(lambda: mock_api_client, schema_id=1, agent_id="hello", use_cache="never")(self.fn_run) + assert isinstance(wrapped, _RunnableAgent) + + mock_api_client.post.return_value = RunResponse(id="1", task_output={"message": "Hello, World!"}) + run = await wrapped(HelloTaskInput(name="World")) + assert isinstance(run, Run) + + mock_api_client.post.assert_called_once() + req = mock_api_client.post.call_args.args[1] + assert isinstance(req, RunRequest) + assert req.use_cache == "never" + + mock_api_client.post.reset_mock() + + # Check that it can be overridden + _ = await wrapped(HelloTaskInput(name="World"), use_cache="always") + + mock_api_client.post.assert_called_once() + req = mock_api_client.post.call_args.args[1] + assert isinstance(req, RunRequest) + assert req.use_cache == "always" + def fn_stream(self, task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... async def test_fn_stream(self, mock_api_client: Mock): diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 75e61cf..411c49c 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -103,13 +103,13 @@ def tool_call_to_domain(tool_call: ToolCall) -> DToolCall: ) -class ToolCallRequest(TypedDict): +class ToolCallRequestDict(TypedDict): id: str name: str input: dict[str, Any] -def tool_call_request_to_domain(tool_call_request: ToolCallRequest) -> DToolCallRequest: +def tool_call_request_to_domain(tool_call_request: ToolCallRequestDict) -> DToolCallRequest: return DToolCallRequest( id=tool_call_request["id"], name=tool_call_request["name"], @@ -119,7 +119,7 @@ def tool_call_request_to_domain(tool_call_request: ToolCallRequest) -> DToolCall class RunResponse(BaseModel): id: str - task_output: dict[str, Any] + task_output: Optional[dict[str, Any]] = None version: Optional[Version] = None duration_seconds: Optional[float] = None @@ -127,7 +127,7 @@ class RunResponse(BaseModel): metadata: Optional[dict[str, Any]] = None tool_calls: Optional[list[ToolCall]] = None - tool_call_requests: Optional[list[ToolCallRequest]] = None + tool_call_requests: Optional[list[ToolCallRequestDict]] = None feedback_token: Optional[str] = None @@ -147,7 +147,7 @@ def to_domain( id=self.id, agent_id=task_id, schema_id=task_schema_id, - output=validator(self.task_output, partial), + output=validator(self.task_output or {}, partial), version=self.version and self.version.to_domain(), duration_seconds=self.duration_seconds, cost_usd=self.cost_usd, @@ -220,3 +220,10 @@ class CompletionsResponse(BaseModel): """Response from the completions API endpoint.""" completions: list[Completion] + + +class CreateFeedbackRequest(BaseModel): + feedback_token: str + outcome: Literal["positive", "negative"] + comment: Optional[str] + user_id: Optional[str] diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 20a0311..651c4a2 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -6,7 +6,12 @@ from pydantic import BaseModel, ValidationError from typing_extensions import Unpack -from workflowai.core._common_types import BaseRunParams, OutputValidator, VersionRunParams +from workflowai.core._common_types import ( + BaseRunParams, + OtherRunParams, + OutputValidator, + VersionRunParams, +) from workflowai.core.client._api import APIClient from workflowai.core.client._models import ( CompletionsResponse, @@ -27,7 +32,7 @@ global_default_version_reference, ) from workflowai.core.domain.completion import Completion -from workflowai.core.domain.errors import BaseError, WorkflowAIError +from workflowai.core.domain.errors import BaseError, MaxTurnsReachedError, WorkflowAIError from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput from workflowai.core.domain.tool import Tool @@ -83,7 +88,7 @@ class MyOutput(BaseModel): ``` """ - _DEFAULT_MAX_ITERATIONS = 10 + _DEFAULT_MAX_TURNS = 10 def __init__( self, @@ -94,6 +99,7 @@ def __init__( schema_id: Optional[int] = None, version: Optional[VersionReference] = None, tools: Optional[Iterable[Callable[..., Any]]] = None, + **kwargs: Unpack[OtherRunParams], ): self.agent_id = agent_id self.schema_id = schema_id @@ -104,6 +110,7 @@ def __init__( self._tools = self.build_tools(tools) if tools else None self._default_validator = default_validator(output_cls) + self._other_run_params = kwargs @classmethod def build_tools(cls, tools: Iterable[Callable[..., Any]]): @@ -180,6 +187,13 @@ def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[st dumped["temperature"] = combined.temperature return dumped + def _get_run_param(self, key: str, params: OtherRunParams, default: Any = None) -> Any: + if key in params: + return params[key] # pyright: ignore [reportUnknownVariableType] + if key in self._other_run_params: + return self._other_run_params[key] # pyright: ignore [reportUnknownVariableType] + return default + async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]): schema_id = self.schema_id if not schema_id: @@ -192,15 +206,14 @@ async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Un task_input=agent_input.model_dump(by_alias=True), version=version, stream=stream, - use_cache=kwargs.get("use_cache"), + use_cache=self._get_run_param("use_cache", kwargs), metadata=kwargs.get("metadata"), - labels=kwargs.get("labels"), ) route = f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/run" should_retry, wait_for_exception = build_retryable_wait( - kwargs.get("max_retry_delay", 60), - kwargs.get("max_retry_count", 1), + self._get_run_param("max_retry_delay", kwargs, 60), + self._get_run_param("max_retry_count", kwargs, 1), ) return self._PreparedRun(request, route, should_retry, wait_for_exception, schema_id) @@ -227,8 +240,8 @@ async def _prepare_reply( ) route = f"/v1/_/agents/{self.agent_id}/runs/{run_id}/reply" should_retry, wait_for_exception = build_retryable_wait( - kwargs.get("max_retry_delay", 60), - kwargs.get("max_retry_count", 1), + self._get_run_param("max_retry_delay", kwargs, 60), + self._get_run_param("max_retry_count", kwargs, 1), ) return self._PreparedRun(request, route, should_retry, wait_for_exception, self.schema_id) @@ -324,8 +337,14 @@ async def _build_run( run = self._build_run_no_tools(chunk, schema_id, validator) if run.tool_call_requests: - if current_iteration >= kwargs.get("max_iterations", self._DEFAULT_MAX_ITERATIONS): - raise WorkflowAIError(error=BaseError(message="max tool iterations reached"), response=None) + if current_iteration >= self._get_run_param("max_turns", kwargs, self._DEFAULT_MAX_TURNS): + if self._get_run_param("max_turns_raises", kwargs, default=True): + raise MaxTurnsReachedError( + error=BaseError(message="max tool iterations reached"), + response=None, + tool_call_requests=run.tool_call_requests, + ) + return run with_reply = await self._execute_tools( run_id=run.id, tool_call_requests=run.tool_call_requests, @@ -368,7 +387,9 @@ async def run( max_retry_delay (Optional[float], optional): The maximum delay between retries in milliseconds. Defaults to 60000. max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1. - max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10. + max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10. + max_turns_raises (Optional[bool], optional): Whether to raise an error when the maximum number of turns is + reached. Defaults to True. validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output. Returns: @@ -385,7 +406,7 @@ async def run( res, prepared_run.schema_id, validator, - current_iteration=0, + current_iteration=1, # TODO[test]: add test with custom validator **new_kwargs, ) @@ -424,7 +445,7 @@ async def stream( max_retry_delay (Optional[float], optional): The maximum delay between retries in milliseconds. Defaults to 60000. max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1. - max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10. + max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10. validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output. Returns: diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 925f044..8315f1e 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -21,8 +21,9 @@ WorkflowAI, ) from workflowai.core.domain.completion import Completion, CompletionUsage, Message -from workflowai.core.domain.errors import WorkflowAIError +from workflowai.core.domain.errors import MaxTurnsReachedError, WorkflowAIError from workflowai.core.domain.run import Run +from workflowai.core.domain.tool_call import ToolCallRequest from workflowai.core.domain.version_properties import VersionProperties @@ -50,8 +51,8 @@ def agent_with_instructions(api_client: APIClient): @pytest.fixture def agent_with_tools(api_client: APIClient): - def some_tool() -> str: - return "Hello, world!" + def some_tool(arg: int) -> str: + return f"Hello, world {arg}!" return Agent( agent_id="123", @@ -101,6 +102,20 @@ def agent_no_schema(api_client: APIClient): class TestRun: + def _mock_tool_call_requests(self, httpx_mock: HTTPXMock, arg: int = 1): + httpx_mock.add_response( + json={ + "id": "run_1", + "tool_call_requests": [ + { + "id": "1", + "name": "some_tool", + "input": {"arg": arg}, + }, + ], + }, + ) + async def test_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, HelloTaskOutput]): httpx_mock.add_response(json=fixtures_json("task_run.json")) @@ -326,6 +341,135 @@ class AliasOutput(BaseModel): "aliased_val": "3", } + async def test_run_with_tools( + self, + httpx_mock: HTTPXMock, + agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput], + ): + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/run", + json={ + "id": "run_1", + "tool_call_requests": [ + { + "id": "1", + "name": "some_tool", + "input": {"arg": 1}, + }, + ], + }, + ) + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/runs/run_1/reply", + json={ + "id": "run_2", + "task_output": {"message": "blibli!"}, + }, + ) + + out = await agent_with_tools.run(HelloTaskInput(name="Alice")) + assert out.output.message == "blibli!" + + reply_req = httpx_mock.get_request(url="http://localhost:8000/v1/_/agents/123/runs/run_1/reply") + assert reply_req + assert json.loads(reply_req.content)["tool_results"] == [ + { + "id": "1", + "output": "Hello, world 1!", + }, + ] + + async def test_max_turns_default( + self, + httpx_mock: HTTPXMock, + agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput], + ): + for i in range(10): + httpx_mock.add_response( + json={ + "id": f"run_{i}", + "tool_call_requests": [ + { + "id": "1", + "name": "some_tool", + "input": {"arg": i}, + }, + ], + }, + ) + + with pytest.raises(MaxTurnsReachedError) as e: + await agent_with_tools.run(HelloTaskInput(name="Alice")) + + assert e.value.tool_call_requests == [ + ToolCallRequest( + id="1", + name="some_tool", + input={"arg": 9}, + ), + ] + + reqs = httpx_mock.get_requests() + assert len(reqs) == 10 + + async def test_max_turns_0(self, httpx_mock: HTTPXMock, agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput]): + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/run", + json={ + "id": "run_1", + "tool_call_requests": [ + { + "id": "1", + "name": "some_tool", + "input": {"arg": 1}, + }, + ], + }, + ) + + with pytest.raises(MaxTurnsReachedError) as e: + await agent_with_tools.run(HelloTaskInput(name="Alice"), max_turns=0) + + assert e.value.tool_call_requests == [ + ToolCallRequest( + id="1", + name="some_tool", + input={"arg": 1}, + ), + ] + + async def test_max_turns_0_not_raises( + self, + httpx_mock: HTTPXMock, + agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput], + ): + self._mock_tool_call_requests(httpx_mock) + out = await agent_with_tools.run(HelloTaskInput(name="Alice"), max_turns=0, max_turns_raises=False) + assert out.tool_call_requests == [ + ToolCallRequest( + id="1", + name="some_tool", + input={"arg": 1}, + ), + ] + + async def test_max_turns_raises_default( + self, + httpx_mock: HTTPXMock, + agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput], + ): + self._mock_tool_call_requests(httpx_mock) + agent_with_tools._other_run_params = {"max_turns": 0, "max_turns_raises": False} # pyright: ignore [reportPrivateUsage, reportAttributeAccessIssue] + + run = await agent_with_tools.run(HelloTaskInput(name="Alice")) + assert run.tool_call_requests == [ + ToolCallRequest( + id="1", + name="some_tool", + input={"arg": 1}, + ), + ] + class TestSanitizeVersion: def test_global_default(self, agent: Agent[HelloTaskInput, HelloTaskOutput]): diff --git a/workflowai/core/client/client.py b/workflowai/core/client/client.py index 2910f0b..1437e3b 100644 --- a/workflowai/core/client/client.py +++ b/workflowai/core/client/client.py @@ -1,10 +1,12 @@ import importlib.metadata from typing import ( + Literal, Optional, ) from workflowai.core.client._api import APIClient from workflowai.core.client._fn_utils import agent_wrapper +from workflowai.core.client._models import CreateFeedbackRequest from workflowai.core.client._utils import global_default_version_reference from workflowai.core.domain.version_reference import VersionReference @@ -39,3 +41,15 @@ def agent( version: Optional[VersionReference] = None, ): return agent_wrapper(lambda: self.api, schema_id=schema_id, agent_id=id, version=version) + + async def send_feedback( + self, + feedback_token: str, + outcome: Literal["positive", "negative"], + comment: Optional[str] = None, + user_id: Optional[str] = None, + ): + await self.api.post( + "/v1/feedback", + CreateFeedbackRequest(feedback_token=feedback_token, outcome=outcome, comment=comment, user_id=user_id), + ) diff --git a/workflowai/core/client/client_test.py b/workflowai/core/client/client_test.py index 96c6f67..f0d15b6 100644 --- a/workflowai/core/client/client_test.py +++ b/workflowai/core/client/client_test.py @@ -1,3 +1,4 @@ +import json from collections.abc import AsyncIterator from typing import Any from unittest.mock import Mock, patch @@ -10,12 +11,13 @@ from workflowai.core.domain.run import Run -class TestAgentDecorator: - @pytest.fixture - def workflowai(self): - # using httpx_mock to make sure we don't actually call the api - return WorkflowAI(api_key="test") +@pytest.fixture +def workflowai(): + # using httpx_mock to make sure we don't actually call the api + return WorkflowAI(api_key="test") + +class TestAgentDecorator: @pytest.fixture def mock_run_fn(self): with patch("workflowai.core.client.agent.Agent.run", autospec=True) as run_mock: @@ -120,3 +122,24 @@ def fn(task_input: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ... HelloTaskOutput(message="hello"), HelloTaskOutput(message="hello"), ] + + +class TestSendFeedback: + async def test_send_feedback(self, workflowai: WorkflowAI, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://api.workflowai.com/v1/feedback", + method="POST", + status_code=200, + ) + + await workflowai.send_feedback(feedback_token="bliblu", outcome="positive") + + reqs = httpx_mock.get_requests() + assert len(reqs) == 1 + req = reqs[0] + assert req.method == "POST" + assert req.url == "https://api.workflowai.com/v1/feedback" + assert json.loads(req.content) == { + "feedback_token": "bliblu", + "outcome": "positive", + } diff --git a/workflowai/core/domain/errors.py b/workflowai/core/domain/errors.py index 33c126d..6f6e2a2 100644 --- a/workflowai/core/domain/errors.py +++ b/workflowai/core/domain/errors.py @@ -6,6 +6,8 @@ from httpx import Response from pydantic import BaseModel +from workflowai.core.domain import tool_call + ProviderErrorCode = Literal[ # Max number of tokens were exceeded in the prompt "max_tokens_exceeded", @@ -112,12 +114,14 @@ def __init__( run_id: Optional[str] = None, retry_after_delay_seconds: Optional[float] = None, partial_output: Optional[dict[str, Any]] = None, + tool_call_requests: Optional[list["tool_call.ToolCallRequest"]] = None, ): self.error = error self.run_id = run_id self.response = response self._retry_after_delay_seconds = retry_after_delay_seconds self.partial_output = partial_output + self.tool_call_requests = tool_call_requests def __str__(self): return f"WorkflowAIError : [{self.error.code}] ({self.error.status_code}): [{self.error.message}]" @@ -187,3 +191,6 @@ def details(self) -> Optional[dict[str, Any]]: class InvalidGenerationError(WorkflowAIError): ... + + +class MaxTurnsReachedError(WorkflowAIError): ...