Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.6.1"
version = "0.6.2-dev1"
description = "Python SDK for WorkflowAI"
authors = ["Guillaume Aquilina <guillaume@workflowai.com>"]
readme = "README.md"
Expand Down Expand Up @@ -63,7 +63,7 @@ unfixable = []
[tool.ruff.lint.per-file-ignores]
# in bin we use rich.print
"bin/*" = ["T201"]
"*_test.py" = ["S101"]
"*_test.py" = ["S101", "S106"]
"conftest.py" = ["S101"]
"examples/*" = ["INP001", "T201", "ERA001"]

Expand Down
11 changes: 10 additions & 1 deletion workflowai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable, Iterable
from typing import Any, Optional
from typing import Any, Literal, Optional

from typing_extensions import deprecated

Expand Down Expand Up @@ -82,3 +82,12 @@ def agent(
model=model,
tools=tools,
)


def send_feedback(
feedback_token: str,
outcome: Literal["positive", "negative"],
comment: Optional[str] = None,
user_id: Optional[str] = None,
):
return shared_client.send_feedback(feedback_token, outcome, comment, user_id)
12 changes: 8 additions & 4 deletions workflowai/core/_common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,18 @@ class VersionRunParams(TypedDict):
temperature: NotRequired[Optional[float]]


class BaseRunParams(VersionRunParams):
class OtherRunParams(TypedDict):
use_cache: NotRequired["CacheUsage"]
metadata: NotRequired[Optional[dict[str, Any]]]
labels: NotRequired[Optional[set[str]]]

max_retry_delay: NotRequired[float]
max_retry_count: NotRequired[float]

max_tool_iterations: NotRequired[int] # 10 by default
max_turns: NotRequired[int] # 10 by default
max_turns_raises: NotRequired[bool] # True by default


class BaseRunParams(VersionRunParams, OtherRunParams):
metadata: NotRequired[Optional[dict[str, Any]]]


class RunParams(BaseRunParams, Generic[AgentOutput]):
Expand Down
14 changes: 9 additions & 5 deletions workflowai/core/client/_fn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pydantic import BaseModel, ValidationError
from typing_extensions import Unpack

from workflowai.core._common_types import OtherRunParams
from workflowai.core.client._api import APIClient
from workflowai.core.client._models import RunResponse
from workflowai.core.client._types import (
Expand Down Expand Up @@ -131,7 +132,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp
milliseconds. Defaults to 60000.
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
Defaults to 1.
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
max_turns (Optional[int], optional): Maximum number of tool iteration cycles.
Defaults to 10.
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
output.
Expand Down Expand Up @@ -194,7 +195,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp
milliseconds. Defaults to 60000.
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
Defaults to 1.
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
max_turns (Optional[int], optional): Maximum number of tool iteration cycles.
Defaults to 10.
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
output.
Expand Down Expand Up @@ -234,7 +235,7 @@ def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]):
milliseconds. Defaults to 60000.
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
Defaults to 1.
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
max_turns (Optional[int], optional): Maximum number of tool iteration cycles.
Defaults to 10.
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
output.
Expand Down Expand Up @@ -276,7 +277,7 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp
milliseconds. Defaults to 60000.
max_retry_count (Optional[float], optional): The maximum number of retry attempts.
Defaults to 1.
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles.
max_turns (Optional[int], optional): Maximum number of tool iteration cycles.
Defaults to 10.
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the
output.
Expand Down Expand Up @@ -318,6 +319,7 @@ def wrap_run_template(
model: Optional[ModelOrStr],
fn: RunTemplate[AgentInput, AgentOutput],
tools: Optional[Iterable[Callable[..., Any]]] = None,
run_params: Optional[OtherRunParams] = None,
) -> Union[
_RunnableAgent[AgentInput, AgentOutput],
_RunnableOutputOnlyAgent[AgentInput, AgentOutput],
Expand All @@ -344,6 +346,7 @@ def wrap_run_template(
schema_id=schema_id,
version=version,
tools=tools,
**(run_params or {}),
)


Expand All @@ -358,13 +361,14 @@ def agent_wrapper(
version: Optional[VersionReference] = None,
model: Optional[ModelOrStr] = None,
tools: Optional[Iterable[Callable[..., Any]]] = None,
**kwargs: Unpack[OtherRunParams],
) -> AgentDecorator:
def wrap(fn: RunTemplate[AgentInput, AgentOutput]):
tid = agent_id or agent_id_from_fn_name(fn)
# TODO[types]: Not sure why a cast is needed here
agent = cast(
FinalRunTemplate[AgentInput, AgentOutput],
wrap_run_template(client, tid, schema_id, version, model, fn, tools),
wrap_run_template(client, tid, schema_id, version, model, fn, tools, kwargs),
)

agent.__doc__ = """A class representing an AI agent that can process inputs and generate outputs.
Expand Down
25 changes: 24 additions & 1 deletion workflowai/core/client/_fn_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
get_generic_args,
is_async_iterator,
)
from workflowai.core.client._models import RunResponse
from workflowai.core.client._models import RunRequest, RunResponse
from workflowai.core.domain.run import Run


Expand Down Expand Up @@ -80,6 +80,29 @@ async def test_fn_run(self, mock_api_client: Mock):
assert run.id == "1"
assert run.output == HelloTaskOutput(message="Hello, World!")

async def test_fn_run_with_default_cache(self, mock_api_client: Mock):
wrapped = agent_wrapper(lambda: mock_api_client, schema_id=1, agent_id="hello", use_cache="never")(self.fn_run)
assert isinstance(wrapped, _RunnableAgent)

mock_api_client.post.return_value = RunResponse(id="1", task_output={"message": "Hello, World!"})
run = await wrapped(HelloTaskInput(name="World"))
assert isinstance(run, Run)

mock_api_client.post.assert_called_once()
req = mock_api_client.post.call_args.args[1]
assert isinstance(req, RunRequest)
assert req.use_cache == "never"

mock_api_client.post.reset_mock()

# Check that it can be overridden
_ = await wrapped(HelloTaskInput(name="World"), use_cache="always")

mock_api_client.post.assert_called_once()
req = mock_api_client.post.call_args.args[1]
assert isinstance(req, RunRequest)
assert req.use_cache == "always"

def fn_stream(self, task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ...

async def test_fn_stream(self, mock_api_client: Mock):
Expand Down
17 changes: 12 additions & 5 deletions workflowai/core/client/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ def tool_call_to_domain(tool_call: ToolCall) -> DToolCall:
)


class ToolCallRequest(TypedDict):
class ToolCallRequestDict(TypedDict):
id: str
name: str
input: dict[str, Any]


def tool_call_request_to_domain(tool_call_request: ToolCallRequest) -> DToolCallRequest:
def tool_call_request_to_domain(tool_call_request: ToolCallRequestDict) -> DToolCallRequest:
return DToolCallRequest(
id=tool_call_request["id"],
name=tool_call_request["name"],
Expand All @@ -119,15 +119,15 @@ def tool_call_request_to_domain(tool_call_request: ToolCallRequest) -> DToolCall

class RunResponse(BaseModel):
id: str
task_output: dict[str, Any]
task_output: Optional[dict[str, Any]] = None

version: Optional[Version] = None
duration_seconds: Optional[float] = None
cost_usd: Optional[float] = None
metadata: Optional[dict[str, Any]] = None

tool_calls: Optional[list[ToolCall]] = None
tool_call_requests: Optional[list[ToolCallRequest]] = None
tool_call_requests: Optional[list[ToolCallRequestDict]] = None

feedback_token: Optional[str] = None

Expand All @@ -147,7 +147,7 @@ def to_domain(
id=self.id,
agent_id=task_id,
schema_id=task_schema_id,
output=validator(self.task_output, partial),
output=validator(self.task_output or {}, partial),
version=self.version and self.version.to_domain(),
duration_seconds=self.duration_seconds,
cost_usd=self.cost_usd,
Expand Down Expand Up @@ -220,3 +220,10 @@ class CompletionsResponse(BaseModel):
"""Response from the completions API endpoint."""

completions: list[Completion]


class CreateFeedbackRequest(BaseModel):
feedback_token: str
outcome: Literal["positive", "negative"]
comment: Optional[str]
user_id: Optional[str]
49 changes: 35 additions & 14 deletions workflowai/core/client/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from pydantic import BaseModel, ValidationError
from typing_extensions import Unpack

from workflowai.core._common_types import BaseRunParams, OutputValidator, VersionRunParams
from workflowai.core._common_types import (
BaseRunParams,
OtherRunParams,
OutputValidator,
VersionRunParams,
)
from workflowai.core.client._api import APIClient
from workflowai.core.client._models import (
CompletionsResponse,
Expand All @@ -27,7 +32,7 @@
global_default_version_reference,
)
from workflowai.core.domain.completion import Completion
from workflowai.core.domain.errors import BaseError, WorkflowAIError
from workflowai.core.domain.errors import BaseError, MaxTurnsReachedError, WorkflowAIError
from workflowai.core.domain.run import Run
from workflowai.core.domain.task import AgentInput, AgentOutput
from workflowai.core.domain.tool import Tool
Expand Down Expand Up @@ -83,7 +88,7 @@ class MyOutput(BaseModel):
```
"""

_DEFAULT_MAX_ITERATIONS = 10
_DEFAULT_MAX_TURNS = 10

def __init__(
self,
Expand All @@ -94,6 +99,7 @@ def __init__(
schema_id: Optional[int] = None,
version: Optional[VersionReference] = None,
tools: Optional[Iterable[Callable[..., Any]]] = None,
**kwargs: Unpack[OtherRunParams],
):
self.agent_id = agent_id
self.schema_id = schema_id
Expand All @@ -104,6 +110,7 @@ def __init__(
self._tools = self.build_tools(tools) if tools else None

self._default_validator = default_validator(output_cls)
self._other_run_params = kwargs

@classmethod
def build_tools(cls, tools: Iterable[Callable[..., Any]]):
Expand Down Expand Up @@ -180,6 +187,13 @@ def _sanitize_version(self, params: VersionRunParams) -> Union[str, int, dict[st
dumped["temperature"] = combined.temperature
return dumped

def _get_run_param(self, key: str, params: OtherRunParams, default: Any = None) -> Any:
if key in params:
return params[key] # pyright: ignore [reportUnknownVariableType]
if key in self._other_run_params:
return self._other_run_params[key] # pyright: ignore [reportUnknownVariableType]
return default

async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]):
schema_id = self.schema_id
if not schema_id:
Expand All @@ -192,15 +206,14 @@ async def _prepare_run(self, agent_input: AgentInput, stream: bool, **kwargs: Un
task_input=agent_input.model_dump(by_alias=True),
version=version,
stream=stream,
use_cache=kwargs.get("use_cache"),
use_cache=self._get_run_param("use_cache", kwargs),
metadata=kwargs.get("metadata"),
labels=kwargs.get("labels"),
)

route = f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/run"
should_retry, wait_for_exception = build_retryable_wait(
kwargs.get("max_retry_delay", 60),
kwargs.get("max_retry_count", 1),
self._get_run_param("max_retry_delay", kwargs, 60),
self._get_run_param("max_retry_count", kwargs, 1),
)
return self._PreparedRun(request, route, should_retry, wait_for_exception, schema_id)

Expand All @@ -227,8 +240,8 @@ async def _prepare_reply(
)
route = f"/v1/_/agents/{self.agent_id}/runs/{run_id}/reply"
should_retry, wait_for_exception = build_retryable_wait(
kwargs.get("max_retry_delay", 60),
kwargs.get("max_retry_count", 1),
self._get_run_param("max_retry_delay", kwargs, 60),
self._get_run_param("max_retry_count", kwargs, 1),
)

return self._PreparedRun(request, route, should_retry, wait_for_exception, self.schema_id)
Expand Down Expand Up @@ -324,8 +337,14 @@ async def _build_run(
run = self._build_run_no_tools(chunk, schema_id, validator)

if run.tool_call_requests:
if current_iteration >= kwargs.get("max_iterations", self._DEFAULT_MAX_ITERATIONS):
raise WorkflowAIError(error=BaseError(message="max tool iterations reached"), response=None)
if current_iteration >= self._get_run_param("max_turns", kwargs, self._DEFAULT_MAX_TURNS):
if self._get_run_param("max_turns_raises", kwargs, default=True):
raise MaxTurnsReachedError(
error=BaseError(message="max tool iterations reached"),
response=None,
tool_call_requests=run.tool_call_requests,
)
return run
with_reply = await self._execute_tools(
run_id=run.id,
tool_call_requests=run.tool_call_requests,
Expand Down Expand Up @@ -368,7 +387,9 @@ async def run(
max_retry_delay (Optional[float], optional): The maximum delay between retries in milliseconds.
Defaults to 60000.
max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1.
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
max_turns_raises (Optional[bool], optional): Whether to raise an error when the maximum number of turns is
reached. Defaults to True.
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output.

Returns:
Expand All @@ -385,7 +406,7 @@ async def run(
res,
prepared_run.schema_id,
validator,
current_iteration=0,
current_iteration=1,
# TODO[test]: add test with custom validator
**new_kwargs,
)
Expand Down Expand Up @@ -424,7 +445,7 @@ async def stream(
max_retry_delay (Optional[float], optional): The maximum delay between retries in milliseconds.
Defaults to 60000.
max_retry_count (Optional[float], optional): The maximum number of retry attempts. Defaults to 1.
max_tool_iterations (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
max_turns (Optional[int], optional): Maximum number of tool iteration cycles. Defaults to 10.
validator (Optional[OutputValidator[AgentOutput]], optional): Custom validator for the output.

Returns:
Expand Down
Loading