diff --git a/README.md b/README.md index bd1217f..6f7671f 100644 --- a/README.md +++ b/README.md @@ -116,12 +116,12 @@ run will be created. By default: #### Using different models -WorkflowAI supports a long list of models. The source of truth for models we support is on [workflowai.com](https://workflowai.com). The [Model](./workflowai/core/domain/model.py) type is a good indication of what models are supported at the time of the sdk release, although it may be missing some models since new ones are added all the time. +WorkflowAI supports a long list of models. The source of truth for models we support is on [workflowai.com](https://workflowai.com). The [Model enum](./workflowai/core/domain/model.py) is a good indication of what models are supported at the time of the sdk release, although it may be missing some models since new ones are added all the time. You can set the model explicitly in the agent decorator: ```python -@workflowai.agent(model="gpt-4o") +@workflowai.agent(model=Model.GPT_4O_LATEST) def say_hello(input: Input) -> Output: ... ``` @@ -174,11 +174,20 @@ def say_hello(input: Input) -> AsyncIterator[Run[Output]]: ### Tools -WorkflowAI has a few tools that can be used to enhance the agent's capabilities: +Tools allow enhancing an agent's capabilities by allowing it to call external functions. + +#### WorkflowAI Hosted tools + +WorkflowAI hosts a few tools: - `@browser-text` allows fetching the content of a web page - `@search` allows performing a web search +Hosted tools tend to be faster because there is no back and forth between the client and the WorkflowAI API. Instead, +if a tool call is needed, the WorkflowAI API will call it within a single request. + +A single run will be created for all tool iterations. + To use a tool, simply add it's handles to the instructions (the function docstring): ```python @@ -190,6 +199,47 @@ def say_hello(input: Input) -> Output: ... ``` +#### Custom tools + +Custom tools allow using most functions within a single agent call. If an agent has custom tools, and the model +deems that tools are needed for a particular run, the agent will: + +- call all tools in parallel +- wait for all tools to complete +- reply to the run with the tool outputs +- continue with the next step of the run, and re-execute tools if needed +- ... +- until either no tool calls are requested, the max iteration (10 by default) or the agent has run to completion + +Tools are defined as regular python functions, and can be async or sync. Examples for tools are available in the [tools end 2 end test file](./tests/e2e/tools_test.py). + +> **Important**: It must be possible to determine the schema of a tool from the function signature. This means that +> the function must have type annotations and use standard types or `BaseModel` only for now. + +```python +# Annotations for parameters are passed as property descriptions in the tool schema +def get_current_time(timezone: Annotated[str, "The timezone to get the current time in. e-g Europe/Paris"]) -> str: + """Return the current time in the given timezone in iso format""" + return datetime.now(ZoneInfo(timezone)).isoformat() + +@agent( + id="answer-question", + tools=[get_current_time], + version=VersionProperties(model=Model.GPT_4O_LATEST), +) +async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ... + +run = await answer_question(AnswerQuestionInput(question="What is the current time in Paris?")) +assert run.output.answer +``` + +> It's important to understand that there are actually two runs in a single agent call: +> +> - the first run returns an empty output with a tool call request with a timezone +> - the second run returns the current time in the given timezone +> +> Only the last run is returned to the caller. + ### Error handling Agents can raise errors, for example when the underlying model fails to generate a response or when @@ -210,3 +260,111 @@ except WorkflowAIError as e: print(e.code) print(e.message) ``` + +### Definining input and output types + +There are some important subtleties when defining input and output types. + +#### Descriptions and examples + +Field description and examples are passed to the model and can help stir the output in the right direction. A good +use case is to describe a format or style for a string field + +```python +# summary has no examples or description so the model will likely return a block of text +class SummaryOutput(BaseModel): + summary: str + +# passing the description will help the model return a summary formatted as bullet points +class SummaryOutput(BaseModel): + summary: str = Field(description="A summary, formatted as bullet points") + +# passing examples can help as well +class SummaryOutput(BaseModel): + summary: str = Field(examples=["- Paris is a city in France\n- London is a city in England"]) +``` + +Some notes: + +- there are very little use cases for descriptions and examples in the **input** type. The model will most of the + infer from the value that is passed. An example use case is to use the description for fields that can be missing. +- adding examples that are too numerous or too specific can push the model to restrict the output value + +#### Required versus optional fields + +In short, we recommend using default values for most output fields. + +Pydantic is by default rather strict on model validation. If there is no default value, the field must be provided. +Although the fact that a field is required is passed to the model, the generation can sometimes omit null or empty +values. + +```python +class Input(BaseModel): + name: str + +class OutputStrict(BaseModel): + greeting: str + +@workflowai.agent() +async def say_hello_strict(_: Input) -> OutputStrict: + ... + +try: + run = await say_hello(Input(name="John")) + print(run.output.greeting) # "Hello, John!" +except WorkflowAIError as e: + print(e.code) # "invalid_generation" error code means that the generation did not match the schema + +class OutputTolerant(BaseModel): + greeting: str = "" + +@workflowai.agent() +async def say_hello_tolerant(_: Input) -> OutputTolerant: + ... + +# The invalid_generation is less likely +run = await say_hello_tolerant(Input(name="John")) +if not run.output.greeting: + print("No greeting was generated !") +print(run.output.greeting) # "Hello, John!" + +``` + +> WorkflowAI automatically retries invalid generations once. If a model outputs an object that does not match the +> schema, a new generation is triggered with the previous response and the error message as context. + +Another reason to prefer optional fields in the output is for streaming. Partial outputs are constructed using +`BaseModel.model_construct` when streaming. If a default value is not provided for a field, fields that are +absent will cause `AttributeError` when queried. + +```python +class Input(BaseModel): + name: str + +class OutputStrict(BaseModel): + greeting1: str + greeting2: str + +@workflowai.agent() +def say_hello_strict(_: Input) -> AsyncIterator[Output]: + ... + +async for run in say_hello(Input(name="John")): + try: + print(run.output.greeting1) + except AttributeError: + # run.output.greeting1 has not been generated yet + + +class OutputTolerant(BaseModel): + greeting1: str = "" + greeting2: str = "" + +@workflowai.agent() +def say_hello_tolerant(_: Input) -> AsyncIterator[OutputTolerant]: + ... + +async for run in say_hello(Input(name="John")): + print(run.output.greeting1) # will be empty if the model has not generated it yet + +``` diff --git a/pyproject.toml b/pyproject.toml index 1cce866..77ac7c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev2" +version = "0.6.0.dev3" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/tests/e2e/tools_test.py b/tests/e2e/tools_test.py new file mode 100644 index 0000000..82869a5 --- /dev/null +++ b/tests/e2e/tools_test.py @@ -0,0 +1,64 @@ +from datetime import datetime +from typing import Annotated + +from pydantic import BaseModel +from zoneinfo import ZoneInfo + +from workflowai import Run, agent +from workflowai.core.domain.model import Model +from workflowai.core.domain.tool import Tool +from workflowai.core.domain.tool_call import ToolCallResult +from workflowai.core.domain.version_properties import VersionProperties + + +class AnswerQuestionInput(BaseModel): + question: str + + +class AnswerQuestionOutput(BaseModel): + answer: str = "" + + +async def test_manual_tool(): + get_current_time_tool = Tool( + name="get_current_time", + description="Get the current time", + input_schema={}, + output_schema={ + "properties": { + "time": {"type": "string", "description": "The current time"}, + }, + }, + ) + + @agent( + id="answer-question", + version=VersionProperties(model=Model.GPT_4O_LATEST, enabled_tools=[get_current_time_tool]), + ) + async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ... + + run = await answer_question(AnswerQuestionInput(question="What is the current time spelled out in French?")) + assert not run.output.answer + + assert run.tool_call_requests + assert len(run.tool_call_requests) == 1 + assert run.tool_call_requests[0].name == "get_current_time" + + replied = await run.reply(tool_results=[ToolCallResult(id=run.tool_call_requests[0].id, output={"time": "12:00"})]) + assert replied.output.answer + + +async def test_auto_tool(): + def get_current_time(timezone: Annotated[str, "The timezone to get the current time in. e-g Europe/Paris"]) -> str: + """Return the current time in the given timezone in iso format""" + return datetime.now(ZoneInfo(timezone)).isoformat() + + @agent( + id="answer-question", + tools=[get_current_time], + version=VersionProperties(model=Model.GPT_4O_LATEST), + ) + async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ... + + run = await answer_question(AnswerQuestionInput(question="What is the current time in Paris?")) + assert run.output.answer diff --git a/workflowai/__init__.py b/workflowai/__init__.py index eac4b5a..2906edb 100644 --- a/workflowai/__init__.py +++ b/workflowai/__init__.py @@ -1,10 +1,12 @@ import os -from typing import Optional +from collections.abc import Callable, Iterable +from typing import Any, Optional from typing_extensions import deprecated from workflowai.core.client._types import AgentDecorator from workflowai.core.client.client import WorkflowAI as WorkflowAI +from workflowai.core.domain import model from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError from workflowai.core.domain.model import Model as Model @@ -31,7 +33,7 @@ def _build_client( shared_client: WorkflowAI = _build_client() # The default model to use when running agents without a deployment -DEFAULT_MODEL: Model = os.getenv("WORKFLOWAI_DEFAULT_MODEL", "gemini-1.5-pro-latest") +DEFAULT_MODEL: "model.ModelOrStr" = os.getenv("WORKFLOWAI_DEFAULT_MODEL", "gemini-1.5-pro-latest") def init(api_key: Optional[str] = None, url: Optional[str] = None, default_version: Optional[VersionReference] = None): @@ -66,7 +68,8 @@ def agent( id: Optional[str] = None, # noqa: A002 schema_id: Optional[int] = None, version: Optional[VersionReference] = None, - model: Optional[Model] = None, + model: Optional["model.ModelOrStr"] = None, + tools: Optional[Iterable[Callable[..., Any]]] = None, ) -> AgentDecorator: from workflowai.core.client._fn_utils import agent_wrapper @@ -76,4 +79,5 @@ def agent( agent_id=id, version=version, model=model, + tools=tools, ) diff --git a/workflowai/core/_common_types.py b/workflowai/core/_common_types.py new file mode 100644 index 0000000..af3f595 --- /dev/null +++ b/workflowai/core/_common_types.py @@ -0,0 +1,36 @@ +from typing import ( + Any, + Generic, + Optional, + Protocol, + TypeVar, +) + +from pydantic import BaseModel +from typing_extensions import NotRequired, TypedDict + +from workflowai.core.domain.cache_usage import CacheUsage +from workflowai.core.domain.task import AgentOutput +from workflowai.core.domain.version_reference import VersionReference + +AgentInputContra = TypeVar("AgentInputContra", bound=BaseModel, contravariant=True) +AgentOutputCov = TypeVar("AgentOutputCov", bound=BaseModel, covariant=True) + + +class OutputValidator(Protocol, Generic[AgentOutputCov]): + def __call__(self, data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutputCov: ... + + +class BaseRunParams(TypedDict): + version: NotRequired[Optional["VersionReference"]] + use_cache: NotRequired["CacheUsage"] + metadata: NotRequired[Optional[dict[str, Any]]] + labels: NotRequired[Optional[set[str]]] + max_retry_delay: NotRequired[float] + max_retry_count: NotRequired[float] + + max_tool_iterations: NotRequired[int] # 10 by default + + +class RunParams(BaseRunParams, Generic[AgentOutput]): + validator: NotRequired[OutputValidator["AgentOutput"]] diff --git a/workflowai/core/logger.py b/workflowai/core/_logger.py similarity index 100% rename from workflowai/core/logger.py rename to workflowai/core/_logger.py diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index 57aee97..0f097dd 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -1,6 +1,6 @@ import functools import inspect -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import ( Any, AsyncIterator, @@ -26,7 +26,7 @@ RunTemplate, ) from workflowai.core.client.agent import Agent -from workflowai.core.domain.model import Model +from workflowai.core.domain.model import ModelOrStr from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput from workflowai.core.domain.version_properties import VersionProperties @@ -128,8 +128,9 @@ def wrap_run_template( agent_id: str, schema_id: Optional[int], version: Optional[VersionReference], - model: Optional[Model], + model: Optional[ModelOrStr], fn: RunTemplate[AgentInput, AgentOutput], + tools: Optional[Iterable[Callable[..., Any]]] = None, ) -> Union[ _RunnableAgent[AgentInput, AgentOutput], _RunnableOutputOnlyAgent[AgentInput, AgentOutput], @@ -155,6 +156,7 @@ def wrap_run_template( api=client, schema_id=schema_id, version=version, + tools=tools, ) @@ -167,11 +169,12 @@ def agent_wrapper( schema_id: Optional[int] = None, agent_id: Optional[str] = None, version: Optional[VersionReference] = None, - model: Optional[Model] = None, + model: Optional[ModelOrStr] = None, + tools: Optional[Iterable[Callable[..., Any]]] = None, ) -> AgentDecorator: def wrap(fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]: tid = agent_id or agent_id_from_fn_name(fn) - return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn)) # pyright: ignore [reportReturnType] + return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn, tools)) # pyright: ignore [reportReturnType] # pyright is unhappy with generics return wrap # pyright: ignore [reportReturnType] diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 37f44e9..369b9fe 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -1,14 +1,21 @@ -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import BaseModel, Field from typing_extensions import NotRequired, TypedDict -from workflowai.core.client._types import OutputValidator +from workflowai.core._common_types import OutputValidator from workflowai.core.domain.cache_usage import CacheUsage from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentOutput +from workflowai.core.domain.tool_call import ToolCall as DToolCall +from workflowai.core.domain.tool_call import ToolCallRequest as DToolCallRequest +from workflowai.core.domain.tool_call import ToolCallResult as DToolCallResult from workflowai.core.domain.version import Version as DVersion from workflowai.core.domain.version_properties import VersionProperties as DVersionProperties +from workflowai.core.utils._iter import safe_map_list + +# TODO: we should likely only use typed dicts here to avoid validation issues +# We have some typed dicts but pydantic also validates them class RunRequest(BaseModel): @@ -27,7 +34,29 @@ class RunRequest(BaseModel): stream: Optional[bool] = None -# Not using a base model to avoid validation +class ReplyRequest(BaseModel): + user_response: Optional[str] = None + version: Union[str, int, dict[str, Any]] + metadata: Optional[dict[str, Any]] = None + + class ToolResult(BaseModel): + id: str + output: Optional[Any] + error: Optional[str] + + @classmethod + def from_domain(cls, tool_result: DToolCallResult): + return cls( + id=tool_result.id, + output=tool_result.output, + error=tool_result.error, + ) + + tool_results: Optional[list[ToolResult]] = None + + stream: Optional[bool] = None + + class VersionProperties(TypedDict): model: NotRequired[Optional[str]] provider: NotRequired[Optional[str]] @@ -35,9 +64,55 @@ class VersionProperties(TypedDict): instructions: NotRequired[Optional[str]] +def version_properties_to_domain(properties: VersionProperties) -> DVersionProperties: + return DVersionProperties.model_construct( + None, + **properties, + ) + + class Version(BaseModel): properties: VersionProperties + def to_domain(self) -> DVersion: + return DVersion( + properties=version_properties_to_domain(self.properties), + ) + + +class ToolCall(TypedDict): + id: str + name: str + input_preview: str + output_preview: NotRequired[Optional[str]] + error: NotRequired[Optional[str]] + status: NotRequired[Optional[Literal["success", "failed", "in_progress"]]] + + +def tool_call_to_domain(tool_call: ToolCall) -> DToolCall: + return DToolCall( + id=tool_call["id"], + name=tool_call["name"], + input_preview=tool_call["input_preview"], + output_preview=tool_call.get("output_preview"), + error=tool_call.get("error"), + status=tool_call.get("status"), + ) + + +class ToolCallRequest(TypedDict): + id: str + name: str + input: dict[str, Any] + + +def tool_call_request_to_domain(tool_call_request: ToolCallRequest) -> DToolCallRequest: + return DToolCallRequest( + id=tool_call_request["id"], + name=tool_call_request["name"], + input=tool_call_request["input"], + ) + class RunResponse(BaseModel): id: str @@ -46,22 +121,22 @@ class RunResponse(BaseModel): version: Optional[Version] = None duration_seconds: Optional[float] = None cost_usd: Optional[float] = None + metadata: Optional[dict[str, Any]] = None + + tool_calls: Optional[list[ToolCall]] = None + tool_call_requests: Optional[list[ToolCallRequest]] = None def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidator[AgentOutput]) -> Run[AgentOutput]: return Run( id=self.id, agent_id=task_id, schema_id=task_schema_id, - output=validator(self.task_output), - version=self.version - and DVersion( - properties=DVersionProperties.model_construct( - None, - **self.version.properties, - ), - ), + output=validator(self.task_output, self.tool_call_requests is not None), + version=self.version and self.version.to_domain(), duration_seconds=self.duration_seconds, cost_usd=self.cost_usd, + tool_calls=safe_map_list(self.tool_calls, tool_call_to_domain), + tool_call_requests=safe_map_list(self.tool_call_requests, tool_call_request_to_domain), ) diff --git a/workflowai/core/client/_models_test.py b/workflowai/core/client/_models_test.py index 46b8c72..e12bdbf 100644 --- a/workflowai/core/client/_models_test.py +++ b/workflowai/core/client/_models_test.py @@ -5,8 +5,9 @@ from tests.utils import fixture_text from workflowai.core.client._models import RunResponse -from workflowai.core.client._utils import tolerant_validator +from workflowai.core.client._utils import intolerant_validator, tolerant_validator from workflowai.core.domain.run import Run +from workflowai.core.domain.tool_call import ToolCallRequest @pytest.mark.parametrize( @@ -75,4 +76,14 @@ def test_with_version_validation_fails(self): '{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}', ) with pytest.raises(ValidationError): - chunk.to_domain(task_id="1", task_schema_id=1, validator=_TaskOutput.model_validate) + chunk.to_domain(task_id="1", task_schema_id=1, validator=intolerant_validator(_TaskOutput)) + + def test_with_tool_calls(self): + chunk = RunResponse.model_validate_json( + '{"id": "1", "task_output": {}, "tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}]}', + ) + assert chunk + + parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) + assert isinstance(parsed, Run) + assert parsed.tool_call_requests == [ToolCallRequest(id="1", name="test", input={"a": 1})] diff --git a/workflowai/core/client/_types.py b/workflowai/core/client/_types.py index 7dd3bdc..96cf75a 100644 --- a/workflowai/core/client/_types.py +++ b/workflowai/core/client/_types.py @@ -1,37 +1,18 @@ -from collections.abc import Callable from typing import ( Any, AsyncIterator, Generic, Optional, Protocol, - TypeVar, Union, overload, ) -from pydantic import BaseModel -from typing_extensions import NotRequired, TypedDict, Unpack +from typing_extensions import Unpack -from workflowai.core.domain.cache_usage import CacheUsage +from workflowai.core._common_types import AgentInputContra, AgentOutputCov, RunParams from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput -from workflowai.core.domain.version_reference import VersionReference - -AgentInputContra = TypeVar("AgentInputContra", bound=BaseModel, contravariant=True) -AgentOutputCov = TypeVar("AgentOutputCov", bound=BaseModel, covariant=True) - -OutputValidator = Callable[[dict[str, Any]], AgentOutput] - - -class RunParams(TypedDict, Generic[AgentOutput]): - version: NotRequired[Optional[VersionReference]] - use_cache: NotRequired[CacheUsage] - metadata: NotRequired[Optional[dict[str, Any]]] - labels: NotRequired[Optional[set[str]]] - max_retry_delay: NotRequired[float] - max_retry_count: NotRequired[float] - validator: NotRequired[OutputValidator[AgentOutput]] class RunFn(Protocol, Generic[AgentInputContra, AgentOutput]): diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index 3d3d181..c846ef8 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -6,12 +6,13 @@ import re from json import JSONDecodeError from time import time +from typing import Any -from workflowai.core.client._types import OutputValidator +from workflowai.core._common_types import OutputValidator +from workflowai.core._logger import logger from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.version_reference import VersionReference -from workflowai.core.logger import logger delimiter = re.compile(r'\}\n\ndata: \{"') @@ -86,7 +87,21 @@ async def _wait_for_exception(e: WorkflowAIError): def tolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: - return lambda payload: m.model_construct(None, **payload) + def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: # noqa: ARG001 + return m.model_construct(None, **data) + + return _validator + + +def intolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: + def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: + # When we have tool call requests, the output can be empty + if has_tool_call_requests: + return m.model_construct(None, **data) + + return m.model_validate(data) + + return _validator def global_default_version_reference() -> VersionReference: diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 9793163..cbe06f5 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -1,17 +1,34 @@ -from collections.abc import Awaitable, Callable -from typing import Any, Generic, NamedTuple, Optional, Union +import asyncio +from collections.abc import Awaitable, Callable, Iterable +from typing import Any, Generic, NamedTuple, Optional, Union, cast +from pydantic import BaseModel from typing_extensions import Unpack +from workflowai.core._common_types import BaseRunParams, OutputValidator from workflowai.core.client._api import APIClient -from workflowai.core.client._models import CreateAgentRequest, CreateAgentResponse, RunRequest, RunResponse +from workflowai.core.client._models import ( + CreateAgentRequest, + CreateAgentResponse, + ReplyRequest, + RunRequest, + RunResponse, +) from workflowai.core.client._types import RunParams -from workflowai.core.client._utils import build_retryable_wait, global_default_version_reference, tolerant_validator +from workflowai.core.client._utils import ( + build_retryable_wait, + global_default_version_reference, + intolerant_validator, + tolerant_validator, +) from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput +from workflowai.core.domain.tool import Tool +from workflowai.core.domain.tool_call import ToolCallRequest, ToolCallResult from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference +from workflowai.core.utils._tools import tool_schema class Agent(Generic[AgentInput, AgentOutput]): @@ -23,6 +40,7 @@ def __init__( api: Union[APIClient, Callable[[], APIClient]], schema_id: Optional[int] = None, version: Optional[VersionReference] = None, + tools: Optional[Iterable[Callable[..., Any]]] = None, ): self.agent_id = agent_id self.schema_id = schema_id @@ -30,13 +48,20 @@ def __init__( self.output_cls = output_cls self.version: VersionReference = version or global_default_version_reference() self._api = (lambda: api) if isinstance(api, APIClient) else api + self._tools = self.build_tools(tools) if tools else None + + @classmethod + def build_tools(cls, tools: Iterable[Callable[..., Any]]) -> dict[str, tuple[Tool, Callable[..., Any]]]: + # TODO: we should be more tolerant with errors ? + return {tool.__name__: (tool_schema(tool), tool) for tool in tools} @property def api(self) -> APIClient: return self._api() class _PreparedRun(NamedTuple): - request: RunRequest + # would be nice to use a generic here, but python 3.9 does not support generic NamedTuple + request: BaseModel route: str should_retry: Callable[[], bool] wait_for_exception: Callable[[WorkflowAIError], Awaitable[None]] @@ -53,6 +78,16 @@ def _sanitize_version(self, version: Optional[VersionReference]) -> Union[str, i import workflowai dumped["model"] = workflowai.DEFAULT_MODEL + if self._tools: + dumped["enabled_tools"] = [ + { + "name": tool.name, + "description": tool.description, + "input_schema": tool.input_schema, + "output_schema": tool.output_schema, + } + for tool, _ in self._tools.values() + ] return dumped async def _prepare_run(self, task_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]): @@ -78,6 +113,35 @@ async def _prepare_run(self, task_input: AgentInput, stream: bool, **kwargs: Unp ) return self._PreparedRun(request, route, should_retry, wait_for_exception, schema_id) + async def _prepare_reply( + self, + run_id: str, + user_response: Optional[str], + tool_results: Optional[Iterable[ToolCallResult]], + stream: bool, + **kwargs: Unpack[RunParams[AgentOutput]], + ): + if not self.schema_id: + raise ValueError("schema_id is required") + version = self._sanitize_version(kwargs.get("version")) + + request = ReplyRequest( + user_response=user_response, + version=version, + stream=stream, + metadata=kwargs.get("metadata"), + tool_results=[ReplyRequest.ToolResult.from_domain(tool_result) for tool_result in tool_results] + if tool_results + else None, + ) + route = f"/v1/_/agents/{self.agent_id}/runs/{run_id}/reply" + should_retry, wait_for_exception = build_retryable_wait( + kwargs.get("max_retry_delay", 60), + kwargs.get("max_retry_count", 1), + ) + + return self._PreparedRun(request, route, should_retry, wait_for_exception, self.schema_id) + async def register(self): """Registers the agent and returns the schema id""" res = await self.api.post( @@ -92,6 +156,79 @@ async def register(self): self.schema_id = res.schema_id return res.schema_id + @classmethod + async def _safe_execute_tool(cls, tool_call_request: ToolCallRequest, tool_func: Callable[..., Any]): + try: + output: Any = tool_func(**tool_call_request.input) + if isinstance(output, Awaitable): + output = await output + return ToolCallResult( + id=tool_call_request.id, + output=output, + ) + except Exception as e: # noqa: BLE001 + return ToolCallResult( + id=tool_call_request.id, + error=str(e), + ) + + async def _execute_tools( + self, + run_id: str, + tool_call_requests: Iterable[ToolCallRequest], + current_iteration: int, + **kwargs: Unpack[RunParams[AgentOutput]], + ): + if not self._tools: + return None + + executions: list[tuple[ToolCallRequest, Callable[..., Any]]] = [] + for tool_call_request in tool_call_requests: + if tool_call_request.name not in self._tools: + continue + + _, tool_func = self._tools[tool_call_request.name] + executions.append((tool_call_request, tool_func)) + + if not executions: + return None + + # Executing all tools in parallel + results = await asyncio.gather( + *[self._safe_execute_tool(tool_call_request, tool_func) for tool_call_request, tool_func in executions], + ) + return await self.reply( + run_id=run_id, + tool_results=results, + current_iteration=current_iteration + 1, + **kwargs, + ) + + async def _build_run( + self, + chunk: RunResponse, + schema_id: int, + validator: OutputValidator[AgentOutput], + current_iteration: int, + **kwargs: Unpack[BaseRunParams], + ) -> Run[AgentOutput]: + run = chunk.to_domain(self.agent_id, schema_id, validator) + run._agent = self # pyright: ignore [reportPrivateUsage] + + if run.tool_call_requests: + with_reply = await self._execute_tools( + run_id=run.id, + tool_call_requests=run.tool_call_requests, + current_iteration=current_iteration, + validator=validator, + **kwargs, + ) + # Execute tools return None if there are actually no available tools to execute + if with_reply: + return with_reply + + return run + async def run( self, task_input: AgentInput, @@ -122,13 +259,21 @@ async def run( or an async iterator of output objects """ prepared_run = await self._prepare_run(task_input, stream=False, **kwargs) - validator = kwargs.get("validator") or self.output_cls.model_validate + validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) last_error = None while prepared_run.should_retry(): try: res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) - return res.to_domain(self.agent_id, prepared_run.schema_id, validator) + return await self._build_run( + res, + prepared_run.schema_id, + validator, + current_iteration=0, + # TODO[test]: add test with custom validator + # We popped validator above + **new_kwargs, + ) except WorkflowAIError as e: # noqa: PERF203 last_error = e await prepared_run.wait_for_exception(e) @@ -165,7 +310,7 @@ async def stream( or an async iterator of output objects """ prepared_run = await self._prepare_run(task_input, stream=True, **kwargs) - validator = kwargs.get("validator") or tolerant_validator(self.output_cls) + validator, new_kwargs = self._sanitize_validator(kwargs, tolerant_validator(self.output_cls)) while prepared_run.should_retry(): try: @@ -176,7 +321,38 @@ async def stream( returns=RunResponse, run=True, ): - yield chunk.to_domain(self.agent_id, prepared_run.schema_id, validator) + yield await self._build_run( + chunk, + prepared_run.schema_id, + validator, + current_iteration=0, + **new_kwargs, + ) return except WorkflowAIError as e: # noqa: PERF203 await prepared_run.wait_for_exception(e) + + async def reply( + self, + run_id: str, + user_response: Optional[str] = None, + tool_results: Optional[Iterable[ToolCallResult]] = None, + current_iteration: int = 0, + **kwargs: Unpack[RunParams[AgentOutput]], + ): + prepared_run = await self._prepare_reply(run_id, user_response, tool_results, stream=False, **kwargs) + validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) + + res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) + return await self._build_run( + res, + prepared_run.schema_id, + validator, + current_iteration=current_iteration, + **new_kwargs, + ) + + @classmethod + def _sanitize_validator(cls, kwargs: RunParams[AgentOutput], default: OutputValidator[AgentOutput]): + validator = kwargs.pop("validator", default) + return validator, cast(BaseRunParams, kwargs) diff --git a/workflowai/core/domain/model.py b/workflowai/core/domain/model.py index 56d8e7a..ea881e6 100644 --- a/workflowai/core/domain/model.py +++ b/workflowai/core/domain/model.py @@ -1,105 +1,122 @@ -from typing import Literal, Union +from enum import Enum +from typing import Union -Model = Union[ - Literal[ - # -------------------------------------------------------------------------- - # OpenAI Models - # -------------------------------------------------------------------------- - "gpt-4o-latest", - "gpt-4o-2024-11-20", - "gpt-4o-2024-08-06", - "gpt-4o-2024-05-13", - "gpt-4o-mini-latest", - "gpt-4o-mini-2024-07-18", - "o1-2024-12-17-high", - "o1-2024-12-17", - "o1-2024-12-17-low", - "o1-preview-2024-09-12", - "o1-mini-latest", - "o1-mini-2024-09-12", - "gpt-4o-audio-preview-2024-12-17", - "gpt-4o-audio-preview-2024-10-01", - "gpt-4-turbo-2024-04-09", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-4-1106-vision-preview", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-1106", - # -------------------------------------------------------------------------- - # Gemini Models - # -------------------------------------------------------------------------- - "gemini-2.0-flash-exp", - "gemini-2.0-flash-thinking-exp-1219", - "gemini-1.5-pro-latest", - "gemini-1.5-pro-002", - "gemini-1.5-pro-001", - "gemini-1.5-pro-preview-0514", - "gemini-1.5-pro-preview-0409", - "gemini-1.5-flash-latest", - "gemini-1.5-flash-002", - "gemini-1.5-flash-001", - "gemini-1.5-flash-8b", - "gemini-1.5-flash-preview-0514", - "gemini-exp-1206", - "gemini-exp-1121", - "gemini-1.0-pro-002", - "gemini-1.0-pro-001", - "gemini-1.0-pro-vision-001", - # -------------------------------------------------------------------------- - # Claude Models - # -------------------------------------------------------------------------- - "claude-3-5-sonnet-latest", - "claude-3-5-sonnet-20241022", - "claude-3-5-sonnet-20240620", - "claude-3-5-haiku-latest", - "claude-3-5-haiku-20241022", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - # -------------------------------------------------------------------------- - # Llama Models - # -------------------------------------------------------------------------- - "llama-3.3-70b", - "llama-3.2-90b", - "llama-3.2-11b", - "llama-3.2-11b-vision", - "llama-3.2-3b", - "llama-3.2-1b", - "llama-3.2-90b-vision-preview", - "llama-3.2-90b-text-preview", - "llama-3.2-11b-text-preview", - "llama-3.2-3b-preview", - "llama-3.2-1b-preview", - "llama-3.1-405b", - "llama-3.1-70b", - "llama-3.1-8b", - "llama3-70b-8192", - "llama3-8b-8192", - # -------------------------------------------------------------------------- - # Mistral AI Models - # -------------------------------------------------------------------------- - "mixtral-8x7b-32768", - "mistral-large-2-latest", - "mistral-large-2-2407", - "mistral-large-latest", - "mistral-large-2411", - "pixtral-large-latest", - "pixtral-large-2411", - "pixtral-12b-2409", - "ministral-3b-2410", - "ministral-8b-2410", - "mistral-small-2409", - "codestral-mamba-2407", - # -------------------------------------------------------------------------- - # Qwen Models - # -------------------------------------------------------------------------- - "qwen-v3p2-32b-instruct", - # -------------------------------------------------------------------------- - # DeepSeek Models - # -------------------------------------------------------------------------- - "deepseek-v3-2412", - "deepseek-r1-2501", - ], - # Adding string to allow for any model not currently in the SDK but supported by the API - str, -] + +# All models that were supported at one point by WorkflowAI +# Some are deprecated an remapped. +# +# Notes: +# - DO NOT remove models from this list, only add new ones. If needed, add a replacement model +# in the model datas +# - the order is the same as will be displayed in the UI. If you need a model to be displayed +# higher, comment out the line where it should be in "natural" order, and add another one wherever +# needed for the order +class Model(str, Enum): + # -------------------------------------------------------------------------- + # OpenAI Models + # -------------------------------------------------------------------------- + + GPT_4O_LATEST = "gpt-4o-latest" + GPT_4O_2024_11_20 = "gpt-4o-2024-11-20" + GPT_4O_2024_08_06 = "gpt-4o-2024-08-06" + GPT_4O_2024_05_13 = "gpt-4o-2024-05-13" + GPT_4O_MINI_LATEST = "gpt-4o-mini-latest" + GPT_4O_MINI_2024_07_18 = "gpt-4o-mini-2024-07-18" + O1_2024_12_17_HIGH_REASONING_EFFORT = "o1-2024-12-17-high" + O1_2024_12_17_MEDIUM_REASONING_EFFORT = "o1-2024-12-17" + O1_2024_12_17_LOW_REASONING_EFFORT = "o1-2024-12-17-low" + O1_PREVIEW_2024_09_12 = "o1-preview-2024-09-12" + O1_MINI_LATEST = "o1-mini-latest" + O1_MINI_2024_09_12 = "o1-mini-2024-09-12" + GPT_4O_AUDIO_PREVIEW_2024_12_17 = "gpt-4o-audio-preview-2024-12-17" + GPT_40_AUDIO_PREVIEW_2024_10_01 = "gpt-4o-audio-preview-2024-10-01" + GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09" + GPT_4_0125_PREVIEW = "gpt-4-0125-preview" + GPT_4_1106_PREVIEW = "gpt-4-1106-preview" + GPT_4_1106_VISION_PREVIEW = "gpt-4-1106-vision-preview" + GPT_3_5_TURBO_0125 = "gpt-3.5-turbo-0125" + GPT_3_5_TURBO_1106 = "gpt-3.5-turbo-1106" + + # -------------------------------------------------------------------------- + # Gemini Models + # -------------------------------------------------------------------------- + GEMINI_2_0_FLASH_EXP = "gemini-2.0-flash-exp" + GEMINI_2_0_FLASH_THINKING_EXP_1219 = "gemini-2.0-flash-thinking-exp-1219" + GEMINI_2_0_FLASH_THINKING_EXP_0121 = "gemini-2.0-flash-thinking-exp-01-21" + GEMINI_1_5_PRO_LATEST = "gemini-1.5-pro-latest" + GEMINI_1_5_PRO_002 = "gemini-1.5-pro-002" + GEMINI_1_5_PRO_001 = "gemini-1.5-pro-001" + GEMINI_1_5_PRO_PREVIEW_0514 = "gemini-1.5-pro-preview-0514" + GEMINI_1_5_PRO_PREVIEW_0409 = "gemini-1.5-pro-preview-0409" + GEMINI_1_5_FLASH_LATEST = "gemini-1.5-flash-latest" + GEMINI_1_5_FLASH_002 = "gemini-1.5-flash-002" + GEMINI_1_5_FLASH_001 = "gemini-1.5-flash-001" + GEMINI_1_5_FLASH_8B = "gemini-1.5-flash-8b" + GEMINI_1_5_FLASH_PREVIEW_0514 = "gemini-1.5-flash-preview-0514" + GEMINI_EXP_1206 = "gemini-exp-1206" + GEMINI_EXP_1121 = "gemini-exp-1121" + GEMINI_1_0_PRO_002 = "gemini-1.0-pro-002" + GEMINI_1_0_PRO_001 = "gemini-1.0-pro-001" + GEMINI_1_0_PRO_VISION_001 = "gemini-1.0-pro-vision-001" + + # -------------------------------------------------------------------------- + # Claude Models + # -------------------------------------------------------------------------- + CLAUDE_3_5_SONNET_LATEST = "claude-3-5-sonnet-latest" + CLAUDE_3_5_SONNET_20241022 = "claude-3-5-sonnet-20241022" + CLAUDE_3_5_SONNET_20240620 = "claude-3-5-sonnet-20240620" + CLAUDE_3_5_HAIKU_LATEST = "claude-3-5-haiku-latest" + CLAUDE_3_5_HAIKU_20241022 = "claude-3-5-haiku-20241022" + CLAUDE_3_OPUS_20240229 = "claude-3-opus-20240229" + CLAUDE_3_SONNET_20240229 = "claude-3-sonnet-20240229" + CLAUDE_3_HAIKU_20240307 = "claude-3-haiku-20240307" + + # -------------------------------------------------------------------------- + # Llama Models + # -------------------------------------------------------------------------- + LLAMA_3_3_70B = "llama-3.3-70b" + LLAMA_3_2_90B = "llama-3.2-90b" + LLAMA_3_2_11B = "llama-3.2-11b" + LLAMA_3_2_11B_VISION = "llama-3.2-11b-vision" + LLAMA_3_2_3B = "llama-3.2-3b" + LLAMA_3_2_1B = "llama-3.2-1b" + LLAMA_3_2_90B_VISION_PREVIEW = "llama-3.2-90b-vision-preview" + LLAMA_3_2_90B_TEXT_PREVIEW = "llama-3.2-90b-text-preview" + LLAMA_3_2_11B_TEXT_PREVIEW = "llama-3.2-11b-text-preview" + LLAMA_3_2_3B_PREVIEW = "llama-3.2-3b-preview" + LLAMA_3_2_1B_PREVIEW = "llama-3.2-1b-preview" + LLAMA_3_1_405B = "llama-3.1-405b" + LLAMA_3_1_70B = "llama-3.1-70b" + LLAMA_3_1_8B = "llama-3.1-8b" + LLAMA3_70B_8192 = "llama3-70b-8192" + LLAMA3_8B_8192 = "llama3-8b-8192" + + # -------------------------------------------------------------------------- + # Mistral AI Models + # -------------------------------------------------------------------------- + MIXTRAL_8X7B_32768 = "mixtral-8x7b-32768" + MISTRAL_LARGE_2_LATEST = "mistral-large-2-latest" + MISTRAL_LARGE_2_2407 = "mistral-large-2-2407" + MISTRAL_LARGE_LATEST = "mistral-large-latest" + MISTRAL_LARGE_2411 = "mistral-large-2411" + PIXTRAL_LARGE_LATEST = "pixtral-large-latest" + PIXTRAL_LARGE_2411 = "pixtral-large-2411" + PIXTRAL_12B_2409 = "pixtral-12b-2409" + MINISTRAL_3B_2410 = "ministral-3b-2410" + MINISTRAL_8B_2410 = "ministral-8b-2410" + MISTRAL_SMALL_2409 = "mistral-small-2409" + CODESTRAL_MAMBA_2407 = "codestral-mamba-2407" + + # -------------------------------------------------------------------------- + # Qwen Models + # -------------------------------------------------------------------------- + QWEN_QWQ_32B_PREVIEW = "qwen-v3p2-32b-instruct" + + # -------------------------------------------------------------------------- + # DeepSeek Models + # -------------------------------------------------------------------------- + DEEPSEEK_V3_2412 = "deepseek-v3-2412" + DEEPSEEK_R1_2501 = "deepseek-r1-2501" + + +ModelOrStr = Union[Model, str] diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 08ce2fb..9affd7e 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -1,9 +1,14 @@ import uuid -from typing import Any, Generic, Optional +from collections.abc import Iterable +from typing import Any, Generic, Optional, Protocol from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType] +from typing_extensions import Unpack +from workflowai.core import _common_types +from workflowai.core.client import _types from workflowai.core.domain.task import AgentOutput +from workflowai.core.domain.tool_call import ToolCall, ToolCallRequest, ToolCallResult from workflowai.core.domain.version import Version @@ -33,3 +38,50 @@ class Run(BaseModel, Generic[AgentOutput]): ) metadata: Optional[dict[str, Any]] = None + + tool_calls: Optional[list[ToolCall]] = None + tool_call_requests: Optional[list[ToolCallRequest]] = None + + _agent: Optional["_AgentBase[AgentOutput]"] = None + + def __eq__(self, other: object) -> bool: + # Probably over simplistic but the object is not crazy complicated + # We just need a way to ignore the agent object + if not isinstance(other, Run): + return False + if self.__dict__ == other.__dict__: + return True + # Otherwise we check without the agent + for field, value in self.__dict__.items(): + if field == "_agent": + continue + if not value == other.__dict__.get(field): + return False + return True + + async def reply( + self, + user_response: Optional[str] = None, + tool_results: Optional[Iterable[ToolCallResult]] = None, + **kwargs: Unpack["_common_types.RunParams[AgentOutput]"], + ): + if not self._agent: + raise ValueError("Agent is not set") + return await self._agent.reply( + run_id=self.id, + user_response=user_response, + tool_results=tool_results, + **kwargs, + ) + + +class _AgentBase(Protocol, Generic[AgentOutput]): + async def reply( + self, + run_id: str, + user_response: Optional[str] = None, + tool_results: Optional[Iterable[ToolCallResult]] = None, + **kwargs: Unpack["_types.RunParams[AgentOutput]"], + ) -> "Run[AgentOutput]": + """Reply to a run. Either a user_response or tool_results must be provided.""" + ... diff --git a/workflowai/core/domain/run_test.py b/workflowai/core/domain/run_test.py new file mode 100644 index 0000000..31aa347 --- /dev/null +++ b/workflowai/core/domain/run_test.py @@ -0,0 +1,47 @@ +from unittest.mock import Mock + +import pytest +from pydantic import BaseModel + +from workflowai.core.domain.run import Run +from workflowai.core.domain.version import Version +from workflowai.core.domain.version_properties import VersionProperties + + +class _TestOutput(BaseModel): + message: str + + +@pytest.fixture +def run1() -> Run[_TestOutput]: + return Run[_TestOutput]( + id="test-id", + agent_id="agent-1", + schema_id=1, + output=_TestOutput(message="test output"), + duration_seconds=1.0, + cost_usd=0.1, + version=Version(properties=VersionProperties()), + metadata={"test": "data"}, + tool_calls=[], + tool_call_requests=[], + ) + + +@pytest.fixture +def run2(run1: Run[_TestOutput]) -> Run[_TestOutput]: + return run1.model_copy(deep=True) + + +class TestRunEquality: + def test_identical(self, run1: Run[_TestOutput], run2: Run[_TestOutput]): + assert run1 == run2 + + def test_different_output(self, run1: Run[_TestOutput], run2: Run[_TestOutput]): + run2.output.message = "different output" + assert run1 != run2 + + def test_different_agents(self, run1: Run[_TestOutput], run2: Run[_TestOutput]): + run2._agent = Mock() # pyright: ignore [reportPrivateUsage] + assert run1._agent != run2._agent, "sanity check" # pyright: ignore [reportPrivateUsage] + assert run1 == run2 diff --git a/workflowai/core/domain/tool.py b/workflowai/core/domain/tool.py new file mode 100644 index 0000000..13b02bc --- /dev/null +++ b/workflowai/core/domain/tool.py @@ -0,0 +1,11 @@ +from typing import Any + +from pydantic import BaseModel, Field + + +class Tool(BaseModel): + name: str = Field(description="The name of the tool") + description: str = Field(default="", description="The description of the tool") + + input_schema: dict[str, Any] = Field(description="The input class of the tool") + output_schema: dict[str, Any] = Field(description="The output class of the tool") diff --git a/workflowai/core/domain/tool_call.py b/workflowai/core/domain/tool_call.py new file mode 100644 index 0000000..0ca1565 --- /dev/null +++ b/workflowai/core/domain/tool_call.py @@ -0,0 +1,30 @@ +from typing import Any, Literal, Optional + +from pydantic import BaseModel + + +class ToolCall(BaseModel): + """A tool call that has already been executed, either by workflowai or by the user""" + + id: str + name: str + input_preview: str + output_preview: Optional[str] + error: Optional[str] + status: Optional[Literal["success", "failed", "in_progress"]] + + +class ToolCallRequest(BaseModel): + """A request to execute a tool call""" + + id: str + name: str + input: dict[str, Any] + + +class ToolCallResult(BaseModel): + """The output of a tool call""" + + id: str + output: Optional[Any] = None + error: Optional[str] = None diff --git a/workflowai/core/domain/version_properties.py b/workflowai/core/domain/version_properties.py index 6cb6570..8ba1736 100644 --- a/workflowai/core/domain/version_properties.py +++ b/workflowai/core/domain/version_properties.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import Optional, Union from pydantic import BaseModel, ConfigDict, Field -from workflowai.core.domain.model import Model +from workflowai.core.domain.model import ModelOrStr +from workflowai.core.domain.tool import Tool class VersionProperties(BaseModel): @@ -12,7 +13,7 @@ class VersionProperties(BaseModel): # Allow extra fields to support custom options model_config = ConfigDict(extra="allow") - model: Optional[Model] = Field( + model: Optional[ModelOrStr] = Field( default=None, description="The LLM model used for the run", ) @@ -42,3 +43,8 @@ class VersionProperties(BaseModel): default=None, description="The version of the runner used", ) + + enabled_tools: Optional[list[Union[str, Tool]]] = Field( + default=None, + description="The tools enabled for the run. A string can be used to refer to a tool hosted by WorkflowAI", + ) diff --git a/workflowai/core/utils/__init__.py b/workflowai/core/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/workflowai/core/utils/_iter.py b/workflowai/core/utils/_iter.py new file mode 100644 index 0000000..b35f4c9 --- /dev/null +++ b/workflowai/core/utils/_iter.py @@ -0,0 +1,22 @@ +from collections.abc import Iterator +from typing import Callable, Iterable, Optional + +from workflowai.core._logger import logger +from workflowai.core.utils._vars import T, U + + +def safe_map(iterable: Iterable[T], func: Callable[[T], U]) -> Iterator[U]: + """Map 'iterable' with 'func' and return a list of results, ignoring any errors.""" + + for item in iterable: + try: + yield func(item) + except Exception as e: # noqa: PERF203, BLE001 + logger.exception(e) + + +def safe_map_list(iterable: Optional[Iterable[T]], func: Callable[[T], U]) -> Optional[list[U]]: + if not iterable: + return None + + return list(safe_map(iterable, func)) diff --git a/workflowai/core/utils/_iter_test.py b/workflowai/core/utils/_iter_test.py new file mode 100644 index 0000000..c2528c5 --- /dev/null +++ b/workflowai/core/utils/_iter_test.py @@ -0,0 +1,21 @@ +from workflowai.core.utils._iter import safe_map + + +class TestSafeMap: + def test_safe_map_success(self): + # Test normal mapping without errors + input_list = [1, 2, 3] + result = list(safe_map(input_list, lambda x: x * 2)) + assert result == [2, 4, 6] + + def test_safe_map_with_errors(self): + # Test mapping with some operations that will raise exceptions + def problematic_function(x: int) -> int: + if x == 2: + raise ValueError("Error for number 2") + return x * 2 + + input_list = [1, 2, 3] + result = list(safe_map(input_list, problematic_function)) + # Should skip the error for 2 and continue processing + assert result == [2, 6] diff --git a/workflowai/core/utils/_tools.py b/workflowai/core/utils/_tools.py new file mode 100644 index 0000000..e7be95d --- /dev/null +++ b/workflowai/core/utils/_tools.py @@ -0,0 +1,99 @@ +import inspect +from enum import Enum +from typing import Any, Callable, get_type_hints + +from pydantic import BaseModel + +ToolFunction = Callable[..., Any] + + +def tool_schema(func: ToolFunction): + """Creates JSON schemas for function input parameters and return type. + + Args: + func (Callable[[Any], Any]): a Python callable with annotated types + + Returns: + FunctionJsonSchema: a FunctionJsonSchema object containing the function input/output JSON schemas + """ + from workflowai.core.domain.tool import Tool + + sig = inspect.signature(func) + type_hints = get_type_hints(func, include_extras=True) + + input_schema = _build_input_schema(sig, type_hints) + output_schema = _build_output_schema(type_hints) + + tool_description = inspect.getdoc(func) + + return Tool( + name=func.__name__, + description=tool_description or "", + input_schema=input_schema, + output_schema=output_schema, + ) + + +def _get_type_schema(param_type: type) -> dict[str, Any]: + """Convert a Python type to its corresponding JSON schema type. + + Args: + param_type: The Python type to convert + + Returns: + A dictionary containing the JSON schema type definition + """ + if issubclass(param_type, Enum): + if not issubclass(param_type, str): + raise ValueError(f"Non string enums are not supported: {param_type}") + return {"type": "string", "enum": [e.value for e in param_type]} + + if param_type is str: + return {"type": "string"} + + if param_type in (int, float): + return {"type": "number"} + + if param_type is bool: + return {"type": "boolean"} + + if isinstance(param_type, BaseModel): + return param_type.model_json_schema() + + raise ValueError(f"Unsupported type: {param_type}") + + +def _build_input_schema(sig: inspect.Signature, type_hints: dict[str, Any]) -> dict[str, Any]: + input_schema: dict[str, Any] = {"type": "object", "properties": {}, "required": []} + + for param_name, param in sig.parameters.items(): + if param_name == "self": + continue + + param_type_hint = type_hints[param_name] + param_type = param_type_hint.__origin__ if hasattr(param_type_hint, "__origin__") else param_type_hint + param_description = param_type_hint.__metadata__[0] if hasattr(param_type_hint, "__metadata__") else None + + param_schema = _get_type_schema(param_type) if isinstance(param_type, type) else {"type": "string"} + if param_description is not None: + param_schema["description"] = param_description + + if param.default is inspect.Parameter.empty: + input_schema["required"].append(param_name) + + input_schema["properties"][param_name] = param_schema + + return input_schema + + +def _build_output_schema(type_hints: dict[str, Any]) -> dict[str, Any]: + return_type = type_hints.get("return") + if not return_type: + raise ValueError("Return type annotation is required") + + return_type_base = return_type.__origin__ if hasattr(return_type, "__origin__") else return_type + + if not isinstance(return_type_base, type): + raise ValueError(f"Unsupported return type: {return_type_base}") + + return _get_type_schema(return_type_base) diff --git a/workflowai/core/utils/_tools_test.py b/workflowai/core/utils/_tools_test.py new file mode 100644 index 0000000..6dc7628 --- /dev/null +++ b/workflowai/core/utils/_tools_test.py @@ -0,0 +1,72 @@ +from enum import Enum +from typing import Annotated + +from workflowai.core.utils._tools import tool_schema + + +class TestToolSchema: + def test_function_with_basic_types(self): + class TestMode(str, Enum): + FAST = "fast" + SLOW = "slow" + + def sample_func( + name: Annotated[str, "The name parameter"], + age: int, + height: float, + is_active: bool, + mode: TestMode = TestMode.FAST, + ) -> bool: + """Sample function for testing""" + ... + + schema = tool_schema(sample_func) + + assert schema.name == "sample_func" + assert schema.input_schema == { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name parameter", + }, + "age": { + "type": "number", + }, + "height": { + "type": "number", + }, + "is_active": { + "type": "boolean", + }, + "mode": { + "type": "string", + "enum": ["fast", "slow"], + }, + }, + "required": ["name", "age", "height", "is_active"], # 'mode' is not required + } + assert schema.output_schema == { + "type": "boolean", + } + assert schema.description == "Sample function for testing" + + def test_method_with_self(self): + class TestClass: + def sample_method(self, value: int) -> str: + return str(value) + + schema = tool_schema(TestClass.sample_method) + + assert schema.input_schema == { + "type": "object", + "properties": { + "value": { + "type": "number", + }, + }, + "required": ["value"], + } + assert schema.output_schema == { + "type": "string", + } diff --git a/workflowai/core/utils/_vars.py b/workflowai/core/utils/_vars.py new file mode 100644 index 0000000..0db5692 --- /dev/null +++ b/workflowai/core/utils/_vars.py @@ -0,0 +1,8 @@ +from typing import TypeVar + +from pydantic import BaseModel + +T = TypeVar("T") +U = TypeVar("U") + +BM = TypeVar("BM", bound=BaseModel)