From 0367b04f9b1a7a716853322c414d73f2590e1c4b Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 29 Jan 2025 16:27:17 -0500 Subject: [PATCH 1/7] feat: add tool call models --- README.md | 2 +- workflowai/__init__.py | 5 +- workflowai/core/{logger.py => _logger.py} | 0 workflowai/core/client/_fn_utils.py | 6 +- workflowai/core/client/_models.py | 69 +++++- workflowai/core/client/_models_test.py | 11 + workflowai/core/client/_utils.py | 2 +- workflowai/core/domain/model.py | 225 ++++++++++--------- workflowai/core/domain/run.py | 4 + workflowai/core/domain/tool_call.py | 29 +++ workflowai/core/domain/version_properties.py | 4 +- workflowai/core/utils/__init__.py | 0 workflowai/core/utils/_iter.py | 22 ++ workflowai/core/utils/_iter_test.py | 21 ++ workflowai/core/utils/_vars.py | 4 + 15 files changed, 282 insertions(+), 122 deletions(-) rename workflowai/core/{logger.py => _logger.py} (100%) create mode 100644 workflowai/core/domain/tool_call.py create mode 100644 workflowai/core/utils/__init__.py create mode 100644 workflowai/core/utils/_iter.py create mode 100644 workflowai/core/utils/_iter_test.py create mode 100644 workflowai/core/utils/_vars.py diff --git a/README.md b/README.md index bd1217f..dc1a4d3 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,7 @@ WorkflowAI supports a long list of models. The source of truth for models we sup You can set the model explicitly in the agent decorator: ```python -@workflowai.agent(model="gpt-4o") +@workflowai.agent(model=Model.GPT_4O_LATEST) def say_hello(input: Input) -> Output: ... ``` diff --git a/workflowai/__init__.py b/workflowai/__init__.py index eac4b5a..37e4255 100644 --- a/workflowai/__init__.py +++ b/workflowai/__init__.py @@ -5,6 +5,7 @@ from workflowai.core.client._types import AgentDecorator from workflowai.core.client.client import WorkflowAI as WorkflowAI +from workflowai.core.domain import model from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError from workflowai.core.domain.model import Model as Model @@ -31,7 +32,7 @@ def _build_client( shared_client: WorkflowAI = _build_client() # The default model to use when running agents without a deployment -DEFAULT_MODEL: Model = os.getenv("WORKFLOWAI_DEFAULT_MODEL", "gemini-1.5-pro-latest") +DEFAULT_MODEL: "model.ModelOrStr" = os.getenv("WORKFLOWAI_DEFAULT_MODEL", "gemini-1.5-pro-latest") def init(api_key: Optional[str] = None, url: Optional[str] = None, default_version: Optional[VersionReference] = None): @@ -66,7 +67,7 @@ def agent( id: Optional[str] = None, # noqa: A002 schema_id: Optional[int] = None, version: Optional[VersionReference] = None, - model: Optional[Model] = None, + model: Optional["model.ModelOrStr"] = None, ) -> AgentDecorator: from workflowai.core.client._fn_utils import agent_wrapper diff --git a/workflowai/core/logger.py b/workflowai/core/_logger.py similarity index 100% rename from workflowai/core/logger.py rename to workflowai/core/_logger.py diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index 57aee97..6204dca 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -26,7 +26,7 @@ RunTemplate, ) from workflowai.core.client.agent import Agent -from workflowai.core.domain.model import Model +from workflowai.core.domain.model import ModelOrStr from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput from workflowai.core.domain.version_properties import VersionProperties @@ -128,7 +128,7 @@ def wrap_run_template( agent_id: str, schema_id: Optional[int], version: Optional[VersionReference], - model: Optional[Model], + model: Optional[ModelOrStr], fn: RunTemplate[AgentInput, AgentOutput], ) -> Union[ _RunnableAgent[AgentInput, AgentOutput], @@ -167,7 +167,7 @@ def agent_wrapper( schema_id: Optional[int] = None, agent_id: Optional[str] = None, version: Optional[VersionReference] = None, - model: Optional[Model] = None, + model: Optional[ModelOrStr] = None, ) -> AgentDecorator: def wrap(fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]: tid = agent_id or agent_id_from_fn_name(fn) diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 37f44e9..1d9dea7 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Literal, Optional, Union from pydantic import BaseModel, Field from typing_extensions import NotRequired, TypedDict @@ -7,8 +7,14 @@ from workflowai.core.domain.cache_usage import CacheUsage from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentOutput +from workflowai.core.domain.tool_call import ToolCall as DToolCall +from workflowai.core.domain.tool_call import ToolCallRequest as DToolCallRequest from workflowai.core.domain.version import Version as DVersion from workflowai.core.domain.version_properties import VersionProperties as DVersionProperties +from workflowai.core.utils._iter import safe_map_list + +# TODO: we should likely only use typed dicts here to avoid validation issues +# We have some typed dicts but pydantic also validates them class RunRequest(BaseModel): @@ -27,7 +33,6 @@ class RunRequest(BaseModel): stream: Optional[bool] = None -# Not using a base model to avoid validation class VersionProperties(TypedDict): model: NotRequired[Optional[str]] provider: NotRequired[Optional[str]] @@ -35,9 +40,55 @@ class VersionProperties(TypedDict): instructions: NotRequired[Optional[str]] +def version_properties_to_domain(properties: VersionProperties) -> DVersionProperties: + return DVersionProperties.model_construct( + None, + **properties, + ) + + class Version(BaseModel): properties: VersionProperties + def to_domain(self) -> DVersion: + return DVersion( + properties=version_properties_to_domain(self.properties), + ) + + +class ToolCall(TypedDict): + id: str + name: str + input_preview: str + output_preview: NotRequired[Optional[str]] + error: NotRequired[Optional[str]] + status: NotRequired[Optional[Literal["success", "failed", "in_progress"]]] + + +def tool_call_to_domain(tool_call: ToolCall) -> DToolCall: + return DToolCall( + id=tool_call["id"], + name=tool_call["name"], + input_preview=tool_call["input_preview"], + output_preview=tool_call.get("output_preview"), + error=tool_call.get("error"), + status=tool_call.get("status"), + ) + + +class ToolCallRequest(TypedDict): + id: str + name: str + input: dict[str, Any] + + +def tool_call_request_to_domain(tool_call_request: ToolCallRequest) -> DToolCallRequest: + return DToolCallRequest( + id=tool_call_request["id"], + name=tool_call_request["name"], + input=tool_call_request["input"], + ) + class RunResponse(BaseModel): id: str @@ -46,6 +97,10 @@ class RunResponse(BaseModel): version: Optional[Version] = None duration_seconds: Optional[float] = None cost_usd: Optional[float] = None + metadata: Optional[dict[str, Any]] = None + + tool_calls: Optional[list[ToolCall]] = None + tool_call_requests: Optional[list[ToolCallRequest]] = None def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidator[AgentOutput]) -> Run[AgentOutput]: return Run( @@ -53,15 +108,11 @@ def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidato agent_id=task_id, schema_id=task_schema_id, output=validator(self.task_output), - version=self.version - and DVersion( - properties=DVersionProperties.model_construct( - None, - **self.version.properties, - ), - ), + version=self.version and self.version.to_domain(), duration_seconds=self.duration_seconds, cost_usd=self.cost_usd, + tool_calls=safe_map_list(self.tool_calls, tool_call_to_domain), + tool_call_requests=safe_map_list(self.tool_call_requests, tool_call_request_to_domain), ) diff --git a/workflowai/core/client/_models_test.py b/workflowai/core/client/_models_test.py index 46b8c72..800366d 100644 --- a/workflowai/core/client/_models_test.py +++ b/workflowai/core/client/_models_test.py @@ -7,6 +7,7 @@ from workflowai.core.client._models import RunResponse from workflowai.core.client._utils import tolerant_validator from workflowai.core.domain.run import Run +from workflowai.core.domain.tool_call import ToolCallRequest @pytest.mark.parametrize( @@ -76,3 +77,13 @@ def test_with_version_validation_fails(self): ) with pytest.raises(ValidationError): chunk.to_domain(task_id="1", task_schema_id=1, validator=_TaskOutput.model_validate) + + def test_with_tool_calls(self): + chunk = RunResponse.model_validate_json( + '{"id": "1", "task_output": {}, "tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}]}', + ) + assert chunk + + parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) + assert isinstance(parsed, Run) + assert parsed.tool_call_requests == [ToolCallRequest(id="1", name="test", input={"a": 1})] diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index 3d3d181..770014d 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -7,11 +7,11 @@ from json import JSONDecodeError from time import time +from workflowai.core._logger import logger from workflowai.core.client._types import OutputValidator from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.version_reference import VersionReference -from workflowai.core.logger import logger delimiter = re.compile(r'\}\n\ndata: \{"') diff --git a/workflowai/core/domain/model.py b/workflowai/core/domain/model.py index 56d8e7a..ea881e6 100644 --- a/workflowai/core/domain/model.py +++ b/workflowai/core/domain/model.py @@ -1,105 +1,122 @@ -from typing import Literal, Union +from enum import Enum +from typing import Union -Model = Union[ - Literal[ - # -------------------------------------------------------------------------- - # OpenAI Models - # -------------------------------------------------------------------------- - "gpt-4o-latest", - "gpt-4o-2024-11-20", - "gpt-4o-2024-08-06", - "gpt-4o-2024-05-13", - "gpt-4o-mini-latest", - "gpt-4o-mini-2024-07-18", - "o1-2024-12-17-high", - "o1-2024-12-17", - "o1-2024-12-17-low", - "o1-preview-2024-09-12", - "o1-mini-latest", - "o1-mini-2024-09-12", - "gpt-4o-audio-preview-2024-12-17", - "gpt-4o-audio-preview-2024-10-01", - "gpt-4-turbo-2024-04-09", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-4-1106-vision-preview", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-1106", - # -------------------------------------------------------------------------- - # Gemini Models - # -------------------------------------------------------------------------- - "gemini-2.0-flash-exp", - "gemini-2.0-flash-thinking-exp-1219", - "gemini-1.5-pro-latest", - "gemini-1.5-pro-002", - "gemini-1.5-pro-001", - "gemini-1.5-pro-preview-0514", - "gemini-1.5-pro-preview-0409", - "gemini-1.5-flash-latest", - "gemini-1.5-flash-002", - "gemini-1.5-flash-001", - "gemini-1.5-flash-8b", - "gemini-1.5-flash-preview-0514", - "gemini-exp-1206", - "gemini-exp-1121", - "gemini-1.0-pro-002", - "gemini-1.0-pro-001", - "gemini-1.0-pro-vision-001", - # -------------------------------------------------------------------------- - # Claude Models - # -------------------------------------------------------------------------- - "claude-3-5-sonnet-latest", - "claude-3-5-sonnet-20241022", - "claude-3-5-sonnet-20240620", - "claude-3-5-haiku-latest", - "claude-3-5-haiku-20241022", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - # -------------------------------------------------------------------------- - # Llama Models - # -------------------------------------------------------------------------- - "llama-3.3-70b", - "llama-3.2-90b", - "llama-3.2-11b", - "llama-3.2-11b-vision", - "llama-3.2-3b", - "llama-3.2-1b", - "llama-3.2-90b-vision-preview", - "llama-3.2-90b-text-preview", - "llama-3.2-11b-text-preview", - "llama-3.2-3b-preview", - "llama-3.2-1b-preview", - "llama-3.1-405b", - "llama-3.1-70b", - "llama-3.1-8b", - "llama3-70b-8192", - "llama3-8b-8192", - # -------------------------------------------------------------------------- - # Mistral AI Models - # -------------------------------------------------------------------------- - "mixtral-8x7b-32768", - "mistral-large-2-latest", - "mistral-large-2-2407", - "mistral-large-latest", - "mistral-large-2411", - "pixtral-large-latest", - "pixtral-large-2411", - "pixtral-12b-2409", - "ministral-3b-2410", - "ministral-8b-2410", - "mistral-small-2409", - "codestral-mamba-2407", - # -------------------------------------------------------------------------- - # Qwen Models - # -------------------------------------------------------------------------- - "qwen-v3p2-32b-instruct", - # -------------------------------------------------------------------------- - # DeepSeek Models - # -------------------------------------------------------------------------- - "deepseek-v3-2412", - "deepseek-r1-2501", - ], - # Adding string to allow for any model not currently in the SDK but supported by the API - str, -] + +# All models that were supported at one point by WorkflowAI +# Some are deprecated an remapped. +# +# Notes: +# - DO NOT remove models from this list, only add new ones. If needed, add a replacement model +# in the model datas +# - the order is the same as will be displayed in the UI. If you need a model to be displayed +# higher, comment out the line where it should be in "natural" order, and add another one wherever +# needed for the order +class Model(str, Enum): + # -------------------------------------------------------------------------- + # OpenAI Models + # -------------------------------------------------------------------------- + + GPT_4O_LATEST = "gpt-4o-latest" + GPT_4O_2024_11_20 = "gpt-4o-2024-11-20" + GPT_4O_2024_08_06 = "gpt-4o-2024-08-06" + GPT_4O_2024_05_13 = "gpt-4o-2024-05-13" + GPT_4O_MINI_LATEST = "gpt-4o-mini-latest" + GPT_4O_MINI_2024_07_18 = "gpt-4o-mini-2024-07-18" + O1_2024_12_17_HIGH_REASONING_EFFORT = "o1-2024-12-17-high" + O1_2024_12_17_MEDIUM_REASONING_EFFORT = "o1-2024-12-17" + O1_2024_12_17_LOW_REASONING_EFFORT = "o1-2024-12-17-low" + O1_PREVIEW_2024_09_12 = "o1-preview-2024-09-12" + O1_MINI_LATEST = "o1-mini-latest" + O1_MINI_2024_09_12 = "o1-mini-2024-09-12" + GPT_4O_AUDIO_PREVIEW_2024_12_17 = "gpt-4o-audio-preview-2024-12-17" + GPT_40_AUDIO_PREVIEW_2024_10_01 = "gpt-4o-audio-preview-2024-10-01" + GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09" + GPT_4_0125_PREVIEW = "gpt-4-0125-preview" + GPT_4_1106_PREVIEW = "gpt-4-1106-preview" + GPT_4_1106_VISION_PREVIEW = "gpt-4-1106-vision-preview" + GPT_3_5_TURBO_0125 = "gpt-3.5-turbo-0125" + GPT_3_5_TURBO_1106 = "gpt-3.5-turbo-1106" + + # -------------------------------------------------------------------------- + # Gemini Models + # -------------------------------------------------------------------------- + GEMINI_2_0_FLASH_EXP = "gemini-2.0-flash-exp" + GEMINI_2_0_FLASH_THINKING_EXP_1219 = "gemini-2.0-flash-thinking-exp-1219" + GEMINI_2_0_FLASH_THINKING_EXP_0121 = "gemini-2.0-flash-thinking-exp-01-21" + GEMINI_1_5_PRO_LATEST = "gemini-1.5-pro-latest" + GEMINI_1_5_PRO_002 = "gemini-1.5-pro-002" + GEMINI_1_5_PRO_001 = "gemini-1.5-pro-001" + GEMINI_1_5_PRO_PREVIEW_0514 = "gemini-1.5-pro-preview-0514" + GEMINI_1_5_PRO_PREVIEW_0409 = "gemini-1.5-pro-preview-0409" + GEMINI_1_5_FLASH_LATEST = "gemini-1.5-flash-latest" + GEMINI_1_5_FLASH_002 = "gemini-1.5-flash-002" + GEMINI_1_5_FLASH_001 = "gemini-1.5-flash-001" + GEMINI_1_5_FLASH_8B = "gemini-1.5-flash-8b" + GEMINI_1_5_FLASH_PREVIEW_0514 = "gemini-1.5-flash-preview-0514" + GEMINI_EXP_1206 = "gemini-exp-1206" + GEMINI_EXP_1121 = "gemini-exp-1121" + GEMINI_1_0_PRO_002 = "gemini-1.0-pro-002" + GEMINI_1_0_PRO_001 = "gemini-1.0-pro-001" + GEMINI_1_0_PRO_VISION_001 = "gemini-1.0-pro-vision-001" + + # -------------------------------------------------------------------------- + # Claude Models + # -------------------------------------------------------------------------- + CLAUDE_3_5_SONNET_LATEST = "claude-3-5-sonnet-latest" + CLAUDE_3_5_SONNET_20241022 = "claude-3-5-sonnet-20241022" + CLAUDE_3_5_SONNET_20240620 = "claude-3-5-sonnet-20240620" + CLAUDE_3_5_HAIKU_LATEST = "claude-3-5-haiku-latest" + CLAUDE_3_5_HAIKU_20241022 = "claude-3-5-haiku-20241022" + CLAUDE_3_OPUS_20240229 = "claude-3-opus-20240229" + CLAUDE_3_SONNET_20240229 = "claude-3-sonnet-20240229" + CLAUDE_3_HAIKU_20240307 = "claude-3-haiku-20240307" + + # -------------------------------------------------------------------------- + # Llama Models + # -------------------------------------------------------------------------- + LLAMA_3_3_70B = "llama-3.3-70b" + LLAMA_3_2_90B = "llama-3.2-90b" + LLAMA_3_2_11B = "llama-3.2-11b" + LLAMA_3_2_11B_VISION = "llama-3.2-11b-vision" + LLAMA_3_2_3B = "llama-3.2-3b" + LLAMA_3_2_1B = "llama-3.2-1b" + LLAMA_3_2_90B_VISION_PREVIEW = "llama-3.2-90b-vision-preview" + LLAMA_3_2_90B_TEXT_PREVIEW = "llama-3.2-90b-text-preview" + LLAMA_3_2_11B_TEXT_PREVIEW = "llama-3.2-11b-text-preview" + LLAMA_3_2_3B_PREVIEW = "llama-3.2-3b-preview" + LLAMA_3_2_1B_PREVIEW = "llama-3.2-1b-preview" + LLAMA_3_1_405B = "llama-3.1-405b" + LLAMA_3_1_70B = "llama-3.1-70b" + LLAMA_3_1_8B = "llama-3.1-8b" + LLAMA3_70B_8192 = "llama3-70b-8192" + LLAMA3_8B_8192 = "llama3-8b-8192" + + # -------------------------------------------------------------------------- + # Mistral AI Models + # -------------------------------------------------------------------------- + MIXTRAL_8X7B_32768 = "mixtral-8x7b-32768" + MISTRAL_LARGE_2_LATEST = "mistral-large-2-latest" + MISTRAL_LARGE_2_2407 = "mistral-large-2-2407" + MISTRAL_LARGE_LATEST = "mistral-large-latest" + MISTRAL_LARGE_2411 = "mistral-large-2411" + PIXTRAL_LARGE_LATEST = "pixtral-large-latest" + PIXTRAL_LARGE_2411 = "pixtral-large-2411" + PIXTRAL_12B_2409 = "pixtral-12b-2409" + MINISTRAL_3B_2410 = "ministral-3b-2410" + MINISTRAL_8B_2410 = "ministral-8b-2410" + MISTRAL_SMALL_2409 = "mistral-small-2409" + CODESTRAL_MAMBA_2407 = "codestral-mamba-2407" + + # -------------------------------------------------------------------------- + # Qwen Models + # -------------------------------------------------------------------------- + QWEN_QWQ_32B_PREVIEW = "qwen-v3p2-32b-instruct" + + # -------------------------------------------------------------------------- + # DeepSeek Models + # -------------------------------------------------------------------------- + DEEPSEEK_V3_2412 = "deepseek-v3-2412" + DEEPSEEK_R1_2501 = "deepseek-r1-2501" + + +ModelOrStr = Union[Model, str] diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 08ce2fb..6046328 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType] from workflowai.core.domain.task import AgentOutput +from workflowai.core.domain.tool_call import ToolCall, ToolCallRequest from workflowai.core.domain.version import Version @@ -33,3 +34,6 @@ class Run(BaseModel, Generic[AgentOutput]): ) metadata: Optional[dict[str, Any]] = None + + tool_calls: Optional[list[ToolCall]] = None + tool_call_requests: Optional[list[ToolCallRequest]] = None diff --git a/workflowai/core/domain/tool_call.py b/workflowai/core/domain/tool_call.py new file mode 100644 index 0000000..7e260ee --- /dev/null +++ b/workflowai/core/domain/tool_call.py @@ -0,0 +1,29 @@ +from typing import Any, Literal, Optional + +from pydantic import BaseModel + + +class ToolCall(BaseModel): + """A tool call that has already been executed, either by workflowai or by the user""" + + id: str + name: str + input_preview: str + output_preview: Optional[str] + error: Optional[str] + status: Optional[Literal["success", "failed", "in_progress"]] + + +class ToolCallRequest(BaseModel): + """A request to execute a tool call""" + + id: str + name: str + input: dict[str, Any] + + +class ToolCallOutput(BaseModel): + """The output of a tool call""" + + id: str + output: dict[str, Any] diff --git a/workflowai/core/domain/version_properties.py b/workflowai/core/domain/version_properties.py index 6cb6570..afebe8b 100644 --- a/workflowai/core/domain/version_properties.py +++ b/workflowai/core/domain/version_properties.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, ConfigDict, Field -from workflowai.core.domain.model import Model +from workflowai.core.domain.model import ModelOrStr class VersionProperties(BaseModel): @@ -12,7 +12,7 @@ class VersionProperties(BaseModel): # Allow extra fields to support custom options model_config = ConfigDict(extra="allow") - model: Optional[Model] = Field( + model: Optional[ModelOrStr] = Field( default=None, description="The LLM model used for the run", ) diff --git a/workflowai/core/utils/__init__.py b/workflowai/core/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/workflowai/core/utils/_iter.py b/workflowai/core/utils/_iter.py new file mode 100644 index 0000000..b35f4c9 --- /dev/null +++ b/workflowai/core/utils/_iter.py @@ -0,0 +1,22 @@ +from collections.abc import Iterator +from typing import Callable, Iterable, Optional + +from workflowai.core._logger import logger +from workflowai.core.utils._vars import T, U + + +def safe_map(iterable: Iterable[T], func: Callable[[T], U]) -> Iterator[U]: + """Map 'iterable' with 'func' and return a list of results, ignoring any errors.""" + + for item in iterable: + try: + yield func(item) + except Exception as e: # noqa: PERF203, BLE001 + logger.exception(e) + + +def safe_map_list(iterable: Optional[Iterable[T]], func: Callable[[T], U]) -> Optional[list[U]]: + if not iterable: + return None + + return list(safe_map(iterable, func)) diff --git a/workflowai/core/utils/_iter_test.py b/workflowai/core/utils/_iter_test.py new file mode 100644 index 0000000..c2528c5 --- /dev/null +++ b/workflowai/core/utils/_iter_test.py @@ -0,0 +1,21 @@ +from workflowai.core.utils._iter import safe_map + + +class TestSafeMap: + def test_safe_map_success(self): + # Test normal mapping without errors + input_list = [1, 2, 3] + result = list(safe_map(input_list, lambda x: x * 2)) + assert result == [2, 4, 6] + + def test_safe_map_with_errors(self): + # Test mapping with some operations that will raise exceptions + def problematic_function(x: int) -> int: + if x == 2: + raise ValueError("Error for number 2") + return x * 2 + + input_list = [1, 2, 3] + result = list(safe_map(input_list, problematic_function)) + # Should skip the error for 2 and continue processing + assert result == [2, 6] diff --git a/workflowai/core/utils/_vars.py b/workflowai/core/utils/_vars.py new file mode 100644 index 0000000..3fb8d7a --- /dev/null +++ b/workflowai/core/utils/_vars.py @@ -0,0 +1,4 @@ +from typing import TypeVar + +T = TypeVar("T") +U = TypeVar("U") From 35803ee9b26764eec2df3094cca179fdb1b5881b Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 29 Jan 2025 17:00:07 -0500 Subject: [PATCH 2/7] feat: add basis for tool requests --- tests/e2e/tools_test.py | 39 ++++++++++++++++++++ workflowai/core/domain/tool.py | 11 ++++++ workflowai/core/domain/version_properties.py | 8 +++- 3 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 tests/e2e/tools_test.py create mode 100644 workflowai/core/domain/tool.py diff --git a/tests/e2e/tools_test.py b/tests/e2e/tools_test.py new file mode 100644 index 0000000..7217d08 --- /dev/null +++ b/tests/e2e/tools_test.py @@ -0,0 +1,39 @@ +from pydantic import BaseModel + +from workflowai import Run, agent +from workflowai.core.domain.model import Model +from workflowai.core.domain.tool import Tool +from workflowai.core.domain.version_properties import VersionProperties + + +class AnswerQuestionInput(BaseModel): + question: str + + +class AnswerQuestionOutput(BaseModel): + answer: str = "" + + +_GET_CURRENT_TIME_TOOL = Tool( + name="get_current_time", + description="Get the current time", + input_schema={}, + output_schema={ + "properties": { + "time": {"type": "string", "description": "The current time"}, + }, + }, +) + + +@agent(id="answer-question") +async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ... + + +async def test_tools(): + run = await answer_question( + AnswerQuestionInput(question="What is the current time?"), + version=VersionProperties(model=Model.GPT_4O_LATEST, enabled_tools=[_GET_CURRENT_TIME_TOOL]), + ) + assert not run.output.answer + assert run.tool_call_requests diff --git a/workflowai/core/domain/tool.py b/workflowai/core/domain/tool.py new file mode 100644 index 0000000..13b02bc --- /dev/null +++ b/workflowai/core/domain/tool.py @@ -0,0 +1,11 @@ +from typing import Any + +from pydantic import BaseModel, Field + + +class Tool(BaseModel): + name: str = Field(description="The name of the tool") + description: str = Field(default="", description="The description of the tool") + + input_schema: dict[str, Any] = Field(description="The input class of the tool") + output_schema: dict[str, Any] = Field(description="The output class of the tool") diff --git a/workflowai/core/domain/version_properties.py b/workflowai/core/domain/version_properties.py index afebe8b..8ba1736 100644 --- a/workflowai/core/domain/version_properties.py +++ b/workflowai/core/domain/version_properties.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import Optional, Union from pydantic import BaseModel, ConfigDict, Field from workflowai.core.domain.model import ModelOrStr +from workflowai.core.domain.tool import Tool class VersionProperties(BaseModel): @@ -42,3 +43,8 @@ class VersionProperties(BaseModel): default=None, description="The version of the runner used", ) + + enabled_tools: Optional[list[Union[str, Tool]]] = Field( + default=None, + description="The tools enabled for the run. A string can be used to refer to a tool hosted by WorkflowAI", + ) From ffbf8ee41f9d3f9bc9d1595b328a8e385700617f Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 29 Jan 2025 21:32:00 -0500 Subject: [PATCH 3/7] feat: add reply run function --- tests/e2e/tools_test.py | 17 ++++-- workflowai/core/_common_types.py | 31 ++++++++++ workflowai/core/client/_models.py | 28 ++++++++- workflowai/core/client/_models_test.py | 4 +- workflowai/core/client/_types.py | 23 +------- workflowai/core/client/_utils.py | 19 +++++- workflowai/core/client/agent.py | 82 +++++++++++++++++++++++--- workflowai/core/domain/run.py | 52 +++++++++++++++- workflowai/core/domain/run_test.py | 47 +++++++++++++++ workflowai/core/domain/tool_call.py | 5 +- workflowai/core/utils/_vars.py | 4 ++ 11 files changed, 268 insertions(+), 44 deletions(-) create mode 100644 workflowai/core/_common_types.py create mode 100644 workflowai/core/domain/run_test.py diff --git a/tests/e2e/tools_test.py b/tests/e2e/tools_test.py index 7217d08..7619eaf 100644 --- a/tests/e2e/tools_test.py +++ b/tests/e2e/tools_test.py @@ -3,6 +3,7 @@ from workflowai import Run, agent from workflowai.core.domain.model import Model from workflowai.core.domain.tool import Tool +from workflowai.core.domain.tool_call import ToolCallResult from workflowai.core.domain.version_properties import VersionProperties @@ -26,14 +27,20 @@ class AnswerQuestionOutput(BaseModel): ) -@agent(id="answer-question") +@agent( + id="answer-question", + version=VersionProperties(model=Model.GPT_4O_LATEST, enabled_tools=[_GET_CURRENT_TIME_TOOL]), +) async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ... async def test_tools(): - run = await answer_question( - AnswerQuestionInput(question="What is the current time?"), - version=VersionProperties(model=Model.GPT_4O_LATEST, enabled_tools=[_GET_CURRENT_TIME_TOOL]), - ) + run = await answer_question(AnswerQuestionInput(question="What is the current time spelled out in French?")) assert not run.output.answer + assert run.tool_call_requests + assert len(run.tool_call_requests) == 1 + assert run.tool_call_requests[0].name == "get_current_time" + + replied = await run.reply(tool_results=[ToolCallResult(id=run.tool_call_requests[0].id, output={"time": "12:00"})]) + assert replied.output.answer diff --git a/workflowai/core/_common_types.py b/workflowai/core/_common_types.py new file mode 100644 index 0000000..bf00d1c --- /dev/null +++ b/workflowai/core/_common_types.py @@ -0,0 +1,31 @@ +from typing import ( + Any, + Generic, + Optional, + Protocol, + TypeVar, +) + +from pydantic import BaseModel +from typing_extensions import NotRequired, TypedDict + +from workflowai.core.domain.cache_usage import CacheUsage +from workflowai.core.domain.task import AgentOutput +from workflowai.core.domain.version_reference import VersionReference + +AgentInputContra = TypeVar("AgentInputContra", bound=BaseModel, contravariant=True) +AgentOutputCov = TypeVar("AgentOutputCov", bound=BaseModel, covariant=True) + + +class OutputValidator(Protocol, Generic[AgentOutputCov]): + def __call__(self, data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutputCov: ... + + +class RunParams(TypedDict, Generic[AgentOutput]): + version: NotRequired[Optional["VersionReference"]] + use_cache: NotRequired["CacheUsage"] + metadata: NotRequired[Optional[dict[str, Any]]] + labels: NotRequired[Optional[set[str]]] + max_retry_delay: NotRequired[float] + max_retry_count: NotRequired[float] + validator: NotRequired[OutputValidator["AgentOutput"]] diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 1d9dea7..369b9fe 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -3,12 +3,13 @@ from pydantic import BaseModel, Field from typing_extensions import NotRequired, TypedDict -from workflowai.core.client._types import OutputValidator +from workflowai.core._common_types import OutputValidator from workflowai.core.domain.cache_usage import CacheUsage from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.tool_call import ToolCall as DToolCall from workflowai.core.domain.tool_call import ToolCallRequest as DToolCallRequest +from workflowai.core.domain.tool_call import ToolCallResult as DToolCallResult from workflowai.core.domain.version import Version as DVersion from workflowai.core.domain.version_properties import VersionProperties as DVersionProperties from workflowai.core.utils._iter import safe_map_list @@ -33,6 +34,29 @@ class RunRequest(BaseModel): stream: Optional[bool] = None +class ReplyRequest(BaseModel): + user_response: Optional[str] = None + version: Union[str, int, dict[str, Any]] + metadata: Optional[dict[str, Any]] = None + + class ToolResult(BaseModel): + id: str + output: Optional[Any] + error: Optional[str] + + @classmethod + def from_domain(cls, tool_result: DToolCallResult): + return cls( + id=tool_result.id, + output=tool_result.output, + error=tool_result.error, + ) + + tool_results: Optional[list[ToolResult]] = None + + stream: Optional[bool] = None + + class VersionProperties(TypedDict): model: NotRequired[Optional[str]] provider: NotRequired[Optional[str]] @@ -107,7 +131,7 @@ def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidato id=self.id, agent_id=task_id, schema_id=task_schema_id, - output=validator(self.task_output), + output=validator(self.task_output, self.tool_call_requests is not None), version=self.version and self.version.to_domain(), duration_seconds=self.duration_seconds, cost_usd=self.cost_usd, diff --git a/workflowai/core/client/_models_test.py b/workflowai/core/client/_models_test.py index 800366d..e12bdbf 100644 --- a/workflowai/core/client/_models_test.py +++ b/workflowai/core/client/_models_test.py @@ -5,7 +5,7 @@ from tests.utils import fixture_text from workflowai.core.client._models import RunResponse -from workflowai.core.client._utils import tolerant_validator +from workflowai.core.client._utils import intolerant_validator, tolerant_validator from workflowai.core.domain.run import Run from workflowai.core.domain.tool_call import ToolCallRequest @@ -76,7 +76,7 @@ def test_with_version_validation_fails(self): '{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}', ) with pytest.raises(ValidationError): - chunk.to_domain(task_id="1", task_schema_id=1, validator=_TaskOutput.model_validate) + chunk.to_domain(task_id="1", task_schema_id=1, validator=intolerant_validator(_TaskOutput)) def test_with_tool_calls(self): chunk = RunResponse.model_validate_json( diff --git a/workflowai/core/client/_types.py b/workflowai/core/client/_types.py index 7dd3bdc..96cf75a 100644 --- a/workflowai/core/client/_types.py +++ b/workflowai/core/client/_types.py @@ -1,37 +1,18 @@ -from collections.abc import Callable from typing import ( Any, AsyncIterator, Generic, Optional, Protocol, - TypeVar, Union, overload, ) -from pydantic import BaseModel -from typing_extensions import NotRequired, TypedDict, Unpack +from typing_extensions import Unpack -from workflowai.core.domain.cache_usage import CacheUsage +from workflowai.core._common_types import AgentInputContra, AgentOutputCov, RunParams from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput -from workflowai.core.domain.version_reference import VersionReference - -AgentInputContra = TypeVar("AgentInputContra", bound=BaseModel, contravariant=True) -AgentOutputCov = TypeVar("AgentOutputCov", bound=BaseModel, covariant=True) - -OutputValidator = Callable[[dict[str, Any]], AgentOutput] - - -class RunParams(TypedDict, Generic[AgentOutput]): - version: NotRequired[Optional[VersionReference]] - use_cache: NotRequired[CacheUsage] - metadata: NotRequired[Optional[dict[str, Any]]] - labels: NotRequired[Optional[set[str]]] - max_retry_delay: NotRequired[float] - max_retry_count: NotRequired[float] - validator: NotRequired[OutputValidator[AgentOutput]] class RunFn(Protocol, Generic[AgentInputContra, AgentOutput]): diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index 770014d..c846ef8 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -6,9 +6,10 @@ import re from json import JSONDecodeError from time import time +from typing import Any +from workflowai.core._common_types import OutputValidator from workflowai.core._logger import logger -from workflowai.core.client._types import OutputValidator from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.version_reference import VersionReference @@ -86,7 +87,21 @@ async def _wait_for_exception(e: WorkflowAIError): def tolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: - return lambda payload: m.model_construct(None, **payload) + def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: # noqa: ARG001 + return m.model_construct(None, **data) + + return _validator + + +def intolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: + def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: + # When we have tool call requests, the output can be empty + if has_tool_call_requests: + return m.model_construct(None, **data) + + return m.model_validate(data) + + return _validator def global_default_version_reference() -> VersionReference: diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 9793163..c3c76ff 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -1,17 +1,31 @@ -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Iterable from typing import Any, Generic, NamedTuple, Optional, Union from typing_extensions import Unpack +from workflowai.core._common_types import OutputValidator from workflowai.core.client._api import APIClient -from workflowai.core.client._models import CreateAgentRequest, CreateAgentResponse, RunRequest, RunResponse +from workflowai.core.client._models import ( + CreateAgentRequest, + CreateAgentResponse, + ReplyRequest, + RunRequest, + RunResponse, +) from workflowai.core.client._types import RunParams -from workflowai.core.client._utils import build_retryable_wait, global_default_version_reference, tolerant_validator +from workflowai.core.client._utils import ( + build_retryable_wait, + global_default_version_reference, + intolerant_validator, + tolerant_validator, +) from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput +from workflowai.core.domain.tool_call import ToolCallResult from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference +from workflowai.core.utils._vars import BM class Agent(Generic[AgentInput, AgentOutput]): @@ -35,8 +49,8 @@ def __init__( def api(self) -> APIClient: return self._api() - class _PreparedRun(NamedTuple): - request: RunRequest + class _PreparedRun(NamedTuple, Generic[BM]): + request: BM route: str should_retry: Callable[[], bool] wait_for_exception: Callable[[WorkflowAIError], Awaitable[None]] @@ -78,6 +92,35 @@ async def _prepare_run(self, task_input: AgentInput, stream: bool, **kwargs: Unp ) return self._PreparedRun(request, route, should_retry, wait_for_exception, schema_id) + async def _prepare_reply( + self, + run_id: str, + user_response: Optional[str], + tool_results: Optional[Iterable[ToolCallResult]], + stream: bool, + **kwargs: Unpack[RunParams[AgentOutput]], + ): + if not self.schema_id: + raise ValueError("schema_id is required") + version = self._sanitize_version(kwargs.get("version")) + + request = ReplyRequest( + user_response=user_response, + version=version, + stream=stream, + metadata=kwargs.get("metadata"), + tool_results=[ReplyRequest.ToolResult.from_domain(tool_result) for tool_result in tool_results] + if tool_results + else None, + ) + route = f"/v1/_/agents/{self.agent_id}/runs/{run_id}/reply" + should_retry, wait_for_exception = build_retryable_wait( + kwargs.get("max_retry_delay", 60), + kwargs.get("max_retry_count", 1), + ) + + return self._PreparedRun(request, route, should_retry, wait_for_exception, self.schema_id) + async def register(self): """Registers the agent and returns the schema id""" res = await self.api.post( @@ -92,6 +135,16 @@ async def register(self): self.schema_id = res.schema_id return res.schema_id + def _build_run( + self, + chunk: RunResponse, + schema_id: int, + validator: OutputValidator[AgentOutput], + ) -> Run[AgentOutput]: + run = chunk.to_domain(self.agent_id, schema_id, validator) + run._agent = self # pyright: ignore [reportPrivateUsage] + return run + async def run( self, task_input: AgentInput, @@ -122,13 +175,13 @@ async def run( or an async iterator of output objects """ prepared_run = await self._prepare_run(task_input, stream=False, **kwargs) - validator = kwargs.get("validator") or self.output_cls.model_validate + validator = kwargs.get("validator") or intolerant_validator(self.output_cls) last_error = None while prepared_run.should_retry(): try: res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) - return res.to_domain(self.agent_id, prepared_run.schema_id, validator) + return self._build_run(res, prepared_run.schema_id, validator) except WorkflowAIError as e: # noqa: PERF203 last_error = e await prepared_run.wait_for_exception(e) @@ -176,7 +229,20 @@ async def stream( returns=RunResponse, run=True, ): - yield chunk.to_domain(self.agent_id, prepared_run.schema_id, validator) + yield self._build_run(chunk, prepared_run.schema_id, validator) return except WorkflowAIError as e: # noqa: PERF203 await prepared_run.wait_for_exception(e) + + async def reply( + self, + run_id: str, + user_response: Optional[str] = None, + tool_results: Optional[Iterable[ToolCallResult]] = None, + **kwargs: Unpack[RunParams[AgentOutput]], + ): + prepared_run = await self._prepare_reply(run_id, user_response, tool_results, stream=False, **kwargs) + validator = kwargs.get("validator") or intolerant_validator(self.output_cls) + + res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) + return self._build_run(res, prepared_run.schema_id, validator) diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 6046328..9affd7e 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -1,10 +1,14 @@ import uuid -from typing import Any, Generic, Optional +from collections.abc import Iterable +from typing import Any, Generic, Optional, Protocol from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType] +from typing_extensions import Unpack +from workflowai.core import _common_types +from workflowai.core.client import _types from workflowai.core.domain.task import AgentOutput -from workflowai.core.domain.tool_call import ToolCall, ToolCallRequest +from workflowai.core.domain.tool_call import ToolCall, ToolCallRequest, ToolCallResult from workflowai.core.domain.version import Version @@ -37,3 +41,47 @@ class Run(BaseModel, Generic[AgentOutput]): tool_calls: Optional[list[ToolCall]] = None tool_call_requests: Optional[list[ToolCallRequest]] = None + + _agent: Optional["_AgentBase[AgentOutput]"] = None + + def __eq__(self, other: object) -> bool: + # Probably over simplistic but the object is not crazy complicated + # We just need a way to ignore the agent object + if not isinstance(other, Run): + return False + if self.__dict__ == other.__dict__: + return True + # Otherwise we check without the agent + for field, value in self.__dict__.items(): + if field == "_agent": + continue + if not value == other.__dict__.get(field): + return False + return True + + async def reply( + self, + user_response: Optional[str] = None, + tool_results: Optional[Iterable[ToolCallResult]] = None, + **kwargs: Unpack["_common_types.RunParams[AgentOutput]"], + ): + if not self._agent: + raise ValueError("Agent is not set") + return await self._agent.reply( + run_id=self.id, + user_response=user_response, + tool_results=tool_results, + **kwargs, + ) + + +class _AgentBase(Protocol, Generic[AgentOutput]): + async def reply( + self, + run_id: str, + user_response: Optional[str] = None, + tool_results: Optional[Iterable[ToolCallResult]] = None, + **kwargs: Unpack["_types.RunParams[AgentOutput]"], + ) -> "Run[AgentOutput]": + """Reply to a run. Either a user_response or tool_results must be provided.""" + ... diff --git a/workflowai/core/domain/run_test.py b/workflowai/core/domain/run_test.py new file mode 100644 index 0000000..31aa347 --- /dev/null +++ b/workflowai/core/domain/run_test.py @@ -0,0 +1,47 @@ +from unittest.mock import Mock + +import pytest +from pydantic import BaseModel + +from workflowai.core.domain.run import Run +from workflowai.core.domain.version import Version +from workflowai.core.domain.version_properties import VersionProperties + + +class _TestOutput(BaseModel): + message: str + + +@pytest.fixture +def run1() -> Run[_TestOutput]: + return Run[_TestOutput]( + id="test-id", + agent_id="agent-1", + schema_id=1, + output=_TestOutput(message="test output"), + duration_seconds=1.0, + cost_usd=0.1, + version=Version(properties=VersionProperties()), + metadata={"test": "data"}, + tool_calls=[], + tool_call_requests=[], + ) + + +@pytest.fixture +def run2(run1: Run[_TestOutput]) -> Run[_TestOutput]: + return run1.model_copy(deep=True) + + +class TestRunEquality: + def test_identical(self, run1: Run[_TestOutput], run2: Run[_TestOutput]): + assert run1 == run2 + + def test_different_output(self, run1: Run[_TestOutput], run2: Run[_TestOutput]): + run2.output.message = "different output" + assert run1 != run2 + + def test_different_agents(self, run1: Run[_TestOutput], run2: Run[_TestOutput]): + run2._agent = Mock() # pyright: ignore [reportPrivateUsage] + assert run1._agent != run2._agent, "sanity check" # pyright: ignore [reportPrivateUsage] + assert run1 == run2 diff --git a/workflowai/core/domain/tool_call.py b/workflowai/core/domain/tool_call.py index 7e260ee..0ca1565 100644 --- a/workflowai/core/domain/tool_call.py +++ b/workflowai/core/domain/tool_call.py @@ -22,8 +22,9 @@ class ToolCallRequest(BaseModel): input: dict[str, Any] -class ToolCallOutput(BaseModel): +class ToolCallResult(BaseModel): """The output of a tool call""" id: str - output: dict[str, Any] + output: Optional[Any] = None + error: Optional[str] = None diff --git a/workflowai/core/utils/_vars.py b/workflowai/core/utils/_vars.py index 3fb8d7a..0db5692 100644 --- a/workflowai/core/utils/_vars.py +++ b/workflowai/core/utils/_vars.py @@ -1,4 +1,8 @@ from typing import TypeVar +from pydantic import BaseModel + T = TypeVar("T") U = TypeVar("U") + +BM = TypeVar("BM", bound=BaseModel) From ff6615e5cbc85c9aa15e4dd05f3388911f8d3cd9 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 29 Jan 2025 22:20:39 -0500 Subject: [PATCH 4/7] feat: first implementation of an auto tool call --- tests/e2e/tools_test.py | 52 +++++++---- workflowai/__init__.py | 5 +- workflowai/core/_common_types.py | 7 +- workflowai/core/client/_fn_utils.py | 7 +- workflowai/core/client/agent.py | 129 ++++++++++++++++++++++++--- workflowai/core/utils/_tools.py | 99 ++++++++++++++++++++ workflowai/core/utils/_tools_test.py | 72 +++++++++++++++ 7 files changed, 340 insertions(+), 31 deletions(-) create mode 100644 workflowai/core/utils/_tools.py create mode 100644 workflowai/core/utils/_tools_test.py diff --git a/tests/e2e/tools_test.py b/tests/e2e/tools_test.py index 7619eaf..82869a5 100644 --- a/tests/e2e/tools_test.py +++ b/tests/e2e/tools_test.py @@ -1,4 +1,8 @@ +from datetime import datetime +from typing import Annotated + from pydantic import BaseModel +from zoneinfo import ZoneInfo from workflowai import Run, agent from workflowai.core.domain.model import Model @@ -15,26 +19,24 @@ class AnswerQuestionOutput(BaseModel): answer: str = "" -_GET_CURRENT_TIME_TOOL = Tool( - name="get_current_time", - description="Get the current time", - input_schema={}, - output_schema={ - "properties": { - "time": {"type": "string", "description": "The current time"}, +async def test_manual_tool(): + get_current_time_tool = Tool( + name="get_current_time", + description="Get the current time", + input_schema={}, + output_schema={ + "properties": { + "time": {"type": "string", "description": "The current time"}, + }, }, - }, -) - + ) -@agent( - id="answer-question", - version=VersionProperties(model=Model.GPT_4O_LATEST, enabled_tools=[_GET_CURRENT_TIME_TOOL]), -) -async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ... + @agent( + id="answer-question", + version=VersionProperties(model=Model.GPT_4O_LATEST, enabled_tools=[get_current_time_tool]), + ) + async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ... - -async def test_tools(): run = await answer_question(AnswerQuestionInput(question="What is the current time spelled out in French?")) assert not run.output.answer @@ -44,3 +46,19 @@ async def test_tools(): replied = await run.reply(tool_results=[ToolCallResult(id=run.tool_call_requests[0].id, output={"time": "12:00"})]) assert replied.output.answer + + +async def test_auto_tool(): + def get_current_time(timezone: Annotated[str, "The timezone to get the current time in. e-g Europe/Paris"]) -> str: + """Return the current time in the given timezone in iso format""" + return datetime.now(ZoneInfo(timezone)).isoformat() + + @agent( + id="answer-question", + tools=[get_current_time], + version=VersionProperties(model=Model.GPT_4O_LATEST), + ) + async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ... + + run = await answer_question(AnswerQuestionInput(question="What is the current time in Paris?")) + assert run.output.answer diff --git a/workflowai/__init__.py b/workflowai/__init__.py index 37e4255..2906edb 100644 --- a/workflowai/__init__.py +++ b/workflowai/__init__.py @@ -1,5 +1,6 @@ import os -from typing import Optional +from collections.abc import Callable, Iterable +from typing import Any, Optional from typing_extensions import deprecated @@ -68,6 +69,7 @@ def agent( schema_id: Optional[int] = None, version: Optional[VersionReference] = None, model: Optional["model.ModelOrStr"] = None, + tools: Optional[Iterable[Callable[..., Any]]] = None, ) -> AgentDecorator: from workflowai.core.client._fn_utils import agent_wrapper @@ -77,4 +79,5 @@ def agent( agent_id=id, version=version, model=model, + tools=tools, ) diff --git a/workflowai/core/_common_types.py b/workflowai/core/_common_types.py index bf00d1c..af3f595 100644 --- a/workflowai/core/_common_types.py +++ b/workflowai/core/_common_types.py @@ -21,11 +21,16 @@ class OutputValidator(Protocol, Generic[AgentOutputCov]): def __call__(self, data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutputCov: ... -class RunParams(TypedDict, Generic[AgentOutput]): +class BaseRunParams(TypedDict): version: NotRequired[Optional["VersionReference"]] use_cache: NotRequired["CacheUsage"] metadata: NotRequired[Optional[dict[str, Any]]] labels: NotRequired[Optional[set[str]]] max_retry_delay: NotRequired[float] max_retry_count: NotRequired[float] + + max_tool_iterations: NotRequired[int] # 10 by default + + +class RunParams(BaseRunParams, Generic[AgentOutput]): validator: NotRequired[OutputValidator["AgentOutput"]] diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index 6204dca..0f097dd 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -1,6 +1,6 @@ import functools import inspect -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import ( Any, AsyncIterator, @@ -130,6 +130,7 @@ def wrap_run_template( version: Optional[VersionReference], model: Optional[ModelOrStr], fn: RunTemplate[AgentInput, AgentOutput], + tools: Optional[Iterable[Callable[..., Any]]] = None, ) -> Union[ _RunnableAgent[AgentInput, AgentOutput], _RunnableOutputOnlyAgent[AgentInput, AgentOutput], @@ -155,6 +156,7 @@ def wrap_run_template( api=client, schema_id=schema_id, version=version, + tools=tools, ) @@ -168,10 +170,11 @@ def agent_wrapper( agent_id: Optional[str] = None, version: Optional[VersionReference] = None, model: Optional[ModelOrStr] = None, + tools: Optional[Iterable[Callable[..., Any]]] = None, ) -> AgentDecorator: def wrap(fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]: tid = agent_id or agent_id_from_fn_name(fn) - return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn)) # pyright: ignore [reportReturnType] + return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn, tools)) # pyright: ignore [reportReturnType] # pyright is unhappy with generics return wrap # pyright: ignore [reportReturnType] diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index c3c76ff..002beaa 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -1,9 +1,10 @@ +import asyncio from collections.abc import Awaitable, Callable, Iterable -from typing import Any, Generic, NamedTuple, Optional, Union +from typing import Any, Generic, NamedTuple, Optional, Union, cast from typing_extensions import Unpack -from workflowai.core._common_types import OutputValidator +from workflowai.core._common_types import BaseRunParams, OutputValidator from workflowai.core.client._api import APIClient from workflowai.core.client._models import ( CreateAgentRequest, @@ -22,9 +23,11 @@ from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput -from workflowai.core.domain.tool_call import ToolCallResult +from workflowai.core.domain.tool import Tool +from workflowai.core.domain.tool_call import ToolCallRequest, ToolCallResult from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference +from workflowai.core.utils._tools import tool_schema from workflowai.core.utils._vars import BM @@ -37,6 +40,7 @@ def __init__( api: Union[APIClient, Callable[[], APIClient]], schema_id: Optional[int] = None, version: Optional[VersionReference] = None, + tools: Optional[Iterable[Callable[..., Any]]] = None, ): self.agent_id = agent_id self.schema_id = schema_id @@ -44,6 +48,12 @@ def __init__( self.output_cls = output_cls self.version: VersionReference = version or global_default_version_reference() self._api = (lambda: api) if isinstance(api, APIClient) else api + self._tools = self.build_tools(tools) if tools else None + + @classmethod + def build_tools(cls, tools: Iterable[Callable[..., Any]]) -> dict[str, tuple[Tool, Callable[..., Any]]]: + # TODO: we should be more tolerant with errors ? + return {tool.__name__: (tool_schema(tool), tool) for tool in tools} @property def api(self) -> APIClient: @@ -67,6 +77,16 @@ def _sanitize_version(self, version: Optional[VersionReference]) -> Union[str, i import workflowai dumped["model"] = workflowai.DEFAULT_MODEL + if self._tools: + dumped["enabled_tools"] = [ + { + "name": tool.name, + "description": tool.description, + "input_schema": tool.input_schema, + "output_schema": tool.output_schema, + } + for tool, _ in self._tools.values() + ] return dumped async def _prepare_run(self, task_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]): @@ -135,14 +155,77 @@ async def register(self): self.schema_id = res.schema_id return res.schema_id - def _build_run( + @classmethod + async def _safe_execute_tool(cls, tool_call_request: ToolCallRequest, tool_func: Callable[..., Any]): + try: + output: Any = tool_func(**tool_call_request.input) + if isinstance(output, Awaitable): + output = await output + return ToolCallResult( + id=tool_call_request.id, + output=output, + ) + except Exception as e: # noqa: BLE001 + return ToolCallResult( + id=tool_call_request.id, + error=str(e), + ) + + async def _execute_tools( + self, + run_id: str, + tool_call_requests: Iterable[ToolCallRequest], + current_iteration: int, + **kwargs: Unpack[RunParams[AgentOutput]], + ): + if not self._tools: + return None + + executions: list[tuple[ToolCallRequest, Callable[..., Any]]] = [] + for tool_call_request in tool_call_requests: + if tool_call_request.name not in self._tools: + continue + + _, tool_func = self._tools[tool_call_request.name] + executions.append((tool_call_request, tool_func)) + + if not executions: + return None + + # Executing all tools in parallel + results = await asyncio.gather( + *[self._safe_execute_tool(tool_call_request, tool_func) for tool_call_request, tool_func in executions], + ) + return await self.reply( + run_id=run_id, + tool_results=results, + current_iteration=current_iteration + 1, + **kwargs, + ) + + async def _build_run( self, chunk: RunResponse, schema_id: int, validator: OutputValidator[AgentOutput], + current_iteration: int, + **kwargs: Unpack[BaseRunParams], ) -> Run[AgentOutput]: run = chunk.to_domain(self.agent_id, schema_id, validator) run._agent = self # pyright: ignore [reportPrivateUsage] + + if run.tool_call_requests: + with_reply = await self._execute_tools( + run_id=run.id, + tool_call_requests=run.tool_call_requests, + current_iteration=current_iteration, + validator=validator, + **kwargs, + ) + # Execute tools return None if there are actually no available tools to execute + if with_reply: + return with_reply + return run async def run( @@ -175,13 +258,21 @@ async def run( or an async iterator of output objects """ prepared_run = await self._prepare_run(task_input, stream=False, **kwargs) - validator = kwargs.get("validator") or intolerant_validator(self.output_cls) + validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) last_error = None while prepared_run.should_retry(): try: res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) - return self._build_run(res, prepared_run.schema_id, validator) + return await self._build_run( + res, + prepared_run.schema_id, + validator, + current_iteration=0, + # TODO[test]: add test with custom validator + # We popped validator above + **new_kwargs, + ) except WorkflowAIError as e: # noqa: PERF203 last_error = e await prepared_run.wait_for_exception(e) @@ -218,7 +309,7 @@ async def stream( or an async iterator of output objects """ prepared_run = await self._prepare_run(task_input, stream=True, **kwargs) - validator = kwargs.get("validator") or tolerant_validator(self.output_cls) + validator, new_kwargs = self._sanitize_validator(kwargs, tolerant_validator(self.output_cls)) while prepared_run.should_retry(): try: @@ -229,7 +320,13 @@ async def stream( returns=RunResponse, run=True, ): - yield self._build_run(chunk, prepared_run.schema_id, validator) + yield await self._build_run( + chunk, + prepared_run.schema_id, + validator, + current_iteration=0, + **new_kwargs, + ) return except WorkflowAIError as e: # noqa: PERF203 await prepared_run.wait_for_exception(e) @@ -239,10 +336,22 @@ async def reply( run_id: str, user_response: Optional[str] = None, tool_results: Optional[Iterable[ToolCallResult]] = None, + current_iteration: int = 0, **kwargs: Unpack[RunParams[AgentOutput]], ): prepared_run = await self._prepare_reply(run_id, user_response, tool_results, stream=False, **kwargs) - validator = kwargs.get("validator") or intolerant_validator(self.output_cls) + validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls)) res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True) - return self._build_run(res, prepared_run.schema_id, validator) + return await self._build_run( + res, + prepared_run.schema_id, + validator, + current_iteration=current_iteration, + **new_kwargs, + ) + + @classmethod + def _sanitize_validator(cls, kwargs: RunParams[AgentOutput], default: OutputValidator[AgentOutput]): + validator = kwargs.pop("validator", default) + return validator, cast(BaseRunParams, kwargs) diff --git a/workflowai/core/utils/_tools.py b/workflowai/core/utils/_tools.py new file mode 100644 index 0000000..e7be95d --- /dev/null +++ b/workflowai/core/utils/_tools.py @@ -0,0 +1,99 @@ +import inspect +from enum import Enum +from typing import Any, Callable, get_type_hints + +from pydantic import BaseModel + +ToolFunction = Callable[..., Any] + + +def tool_schema(func: ToolFunction): + """Creates JSON schemas for function input parameters and return type. + + Args: + func (Callable[[Any], Any]): a Python callable with annotated types + + Returns: + FunctionJsonSchema: a FunctionJsonSchema object containing the function input/output JSON schemas + """ + from workflowai.core.domain.tool import Tool + + sig = inspect.signature(func) + type_hints = get_type_hints(func, include_extras=True) + + input_schema = _build_input_schema(sig, type_hints) + output_schema = _build_output_schema(type_hints) + + tool_description = inspect.getdoc(func) + + return Tool( + name=func.__name__, + description=tool_description or "", + input_schema=input_schema, + output_schema=output_schema, + ) + + +def _get_type_schema(param_type: type) -> dict[str, Any]: + """Convert a Python type to its corresponding JSON schema type. + + Args: + param_type: The Python type to convert + + Returns: + A dictionary containing the JSON schema type definition + """ + if issubclass(param_type, Enum): + if not issubclass(param_type, str): + raise ValueError(f"Non string enums are not supported: {param_type}") + return {"type": "string", "enum": [e.value for e in param_type]} + + if param_type is str: + return {"type": "string"} + + if param_type in (int, float): + return {"type": "number"} + + if param_type is bool: + return {"type": "boolean"} + + if isinstance(param_type, BaseModel): + return param_type.model_json_schema() + + raise ValueError(f"Unsupported type: {param_type}") + + +def _build_input_schema(sig: inspect.Signature, type_hints: dict[str, Any]) -> dict[str, Any]: + input_schema: dict[str, Any] = {"type": "object", "properties": {}, "required": []} + + for param_name, param in sig.parameters.items(): + if param_name == "self": + continue + + param_type_hint = type_hints[param_name] + param_type = param_type_hint.__origin__ if hasattr(param_type_hint, "__origin__") else param_type_hint + param_description = param_type_hint.__metadata__[0] if hasattr(param_type_hint, "__metadata__") else None + + param_schema = _get_type_schema(param_type) if isinstance(param_type, type) else {"type": "string"} + if param_description is not None: + param_schema["description"] = param_description + + if param.default is inspect.Parameter.empty: + input_schema["required"].append(param_name) + + input_schema["properties"][param_name] = param_schema + + return input_schema + + +def _build_output_schema(type_hints: dict[str, Any]) -> dict[str, Any]: + return_type = type_hints.get("return") + if not return_type: + raise ValueError("Return type annotation is required") + + return_type_base = return_type.__origin__ if hasattr(return_type, "__origin__") else return_type + + if not isinstance(return_type_base, type): + raise ValueError(f"Unsupported return type: {return_type_base}") + + return _get_type_schema(return_type_base) diff --git a/workflowai/core/utils/_tools_test.py b/workflowai/core/utils/_tools_test.py new file mode 100644 index 0000000..6dc7628 --- /dev/null +++ b/workflowai/core/utils/_tools_test.py @@ -0,0 +1,72 @@ +from enum import Enum +from typing import Annotated + +from workflowai.core.utils._tools import tool_schema + + +class TestToolSchema: + def test_function_with_basic_types(self): + class TestMode(str, Enum): + FAST = "fast" + SLOW = "slow" + + def sample_func( + name: Annotated[str, "The name parameter"], + age: int, + height: float, + is_active: bool, + mode: TestMode = TestMode.FAST, + ) -> bool: + """Sample function for testing""" + ... + + schema = tool_schema(sample_func) + + assert schema.name == "sample_func" + assert schema.input_schema == { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name parameter", + }, + "age": { + "type": "number", + }, + "height": { + "type": "number", + }, + "is_active": { + "type": "boolean", + }, + "mode": { + "type": "string", + "enum": ["fast", "slow"], + }, + }, + "required": ["name", "age", "height", "is_active"], # 'mode' is not required + } + assert schema.output_schema == { + "type": "boolean", + } + assert schema.description == "Sample function for testing" + + def test_method_with_self(self): + class TestClass: + def sample_method(self, value: int) -> str: + return str(value) + + schema = tool_schema(TestClass.sample_method) + + assert schema.input_schema == { + "type": "object", + "properties": { + "value": { + "type": "number", + }, + }, + "required": ["value"], + } + assert schema.output_schema == { + "type": "string", + } From 0f7f631daca4a430a9cebe9a6764f619c7c0ae8d Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 29 Jan 2025 23:31:38 -0500 Subject: [PATCH 5/7] doc: add tools to readme --- README.md | 162 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 160 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index dc1a4d3..6f7671f 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ run will be created. By default: #### Using different models -WorkflowAI supports a long list of models. The source of truth for models we support is on [workflowai.com](https://workflowai.com). The [Model](./workflowai/core/domain/model.py) type is a good indication of what models are supported at the time of the sdk release, although it may be missing some models since new ones are added all the time. +WorkflowAI supports a long list of models. The source of truth for models we support is on [workflowai.com](https://workflowai.com). The [Model enum](./workflowai/core/domain/model.py) is a good indication of what models are supported at the time of the sdk release, although it may be missing some models since new ones are added all the time. You can set the model explicitly in the agent decorator: @@ -174,11 +174,20 @@ def say_hello(input: Input) -> AsyncIterator[Run[Output]]: ### Tools -WorkflowAI has a few tools that can be used to enhance the agent's capabilities: +Tools allow enhancing an agent's capabilities by allowing it to call external functions. + +#### WorkflowAI Hosted tools + +WorkflowAI hosts a few tools: - `@browser-text` allows fetching the content of a web page - `@search` allows performing a web search +Hosted tools tend to be faster because there is no back and forth between the client and the WorkflowAI API. Instead, +if a tool call is needed, the WorkflowAI API will call it within a single request. + +A single run will be created for all tool iterations. + To use a tool, simply add it's handles to the instructions (the function docstring): ```python @@ -190,6 +199,47 @@ def say_hello(input: Input) -> Output: ... ``` +#### Custom tools + +Custom tools allow using most functions within a single agent call. If an agent has custom tools, and the model +deems that tools are needed for a particular run, the agent will: + +- call all tools in parallel +- wait for all tools to complete +- reply to the run with the tool outputs +- continue with the next step of the run, and re-execute tools if needed +- ... +- until either no tool calls are requested, the max iteration (10 by default) or the agent has run to completion + +Tools are defined as regular python functions, and can be async or sync. Examples for tools are available in the [tools end 2 end test file](./tests/e2e/tools_test.py). + +> **Important**: It must be possible to determine the schema of a tool from the function signature. This means that +> the function must have type annotations and use standard types or `BaseModel` only for now. + +```python +# Annotations for parameters are passed as property descriptions in the tool schema +def get_current_time(timezone: Annotated[str, "The timezone to get the current time in. e-g Europe/Paris"]) -> str: + """Return the current time in the given timezone in iso format""" + return datetime.now(ZoneInfo(timezone)).isoformat() + +@agent( + id="answer-question", + tools=[get_current_time], + version=VersionProperties(model=Model.GPT_4O_LATEST), +) +async def answer_question(_: AnswerQuestionInput) -> Run[AnswerQuestionOutput]: ... + +run = await answer_question(AnswerQuestionInput(question="What is the current time in Paris?")) +assert run.output.answer +``` + +> It's important to understand that there are actually two runs in a single agent call: +> +> - the first run returns an empty output with a tool call request with a timezone +> - the second run returns the current time in the given timezone +> +> Only the last run is returned to the caller. + ### Error handling Agents can raise errors, for example when the underlying model fails to generate a response or when @@ -210,3 +260,111 @@ except WorkflowAIError as e: print(e.code) print(e.message) ``` + +### Definining input and output types + +There are some important subtleties when defining input and output types. + +#### Descriptions and examples + +Field description and examples are passed to the model and can help stir the output in the right direction. A good +use case is to describe a format or style for a string field + +```python +# summary has no examples or description so the model will likely return a block of text +class SummaryOutput(BaseModel): + summary: str + +# passing the description will help the model return a summary formatted as bullet points +class SummaryOutput(BaseModel): + summary: str = Field(description="A summary, formatted as bullet points") + +# passing examples can help as well +class SummaryOutput(BaseModel): + summary: str = Field(examples=["- Paris is a city in France\n- London is a city in England"]) +``` + +Some notes: + +- there are very little use cases for descriptions and examples in the **input** type. The model will most of the + infer from the value that is passed. An example use case is to use the description for fields that can be missing. +- adding examples that are too numerous or too specific can push the model to restrict the output value + +#### Required versus optional fields + +In short, we recommend using default values for most output fields. + +Pydantic is by default rather strict on model validation. If there is no default value, the field must be provided. +Although the fact that a field is required is passed to the model, the generation can sometimes omit null or empty +values. + +```python +class Input(BaseModel): + name: str + +class OutputStrict(BaseModel): + greeting: str + +@workflowai.agent() +async def say_hello_strict(_: Input) -> OutputStrict: + ... + +try: + run = await say_hello(Input(name="John")) + print(run.output.greeting) # "Hello, John!" +except WorkflowAIError as e: + print(e.code) # "invalid_generation" error code means that the generation did not match the schema + +class OutputTolerant(BaseModel): + greeting: str = "" + +@workflowai.agent() +async def say_hello_tolerant(_: Input) -> OutputTolerant: + ... + +# The invalid_generation is less likely +run = await say_hello_tolerant(Input(name="John")) +if not run.output.greeting: + print("No greeting was generated !") +print(run.output.greeting) # "Hello, John!" + +``` + +> WorkflowAI automatically retries invalid generations once. If a model outputs an object that does not match the +> schema, a new generation is triggered with the previous response and the error message as context. + +Another reason to prefer optional fields in the output is for streaming. Partial outputs are constructed using +`BaseModel.model_construct` when streaming. If a default value is not provided for a field, fields that are +absent will cause `AttributeError` when queried. + +```python +class Input(BaseModel): + name: str + +class OutputStrict(BaseModel): + greeting1: str + greeting2: str + +@workflowai.agent() +def say_hello_strict(_: Input) -> AsyncIterator[Output]: + ... + +async for run in say_hello(Input(name="John")): + try: + print(run.output.greeting1) + except AttributeError: + # run.output.greeting1 has not been generated yet + + +class OutputTolerant(BaseModel): + greeting1: str = "" + greeting2: str = "" + +@workflowai.agent() +def say_hello_tolerant(_: Input) -> AsyncIterator[OutputTolerant]: + ... + +async for run in say_hello(Input(name="John")): + print(run.output.greeting1) # will be empty if the model has not generated it yet + +``` From 890254600cb70a1e6f9ca7d04ccae9ede5c941a3 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 29 Jan 2025 23:39:28 -0500 Subject: [PATCH 6/7] fix: generic named tuple --- workflowai/core/client/agent.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 002beaa..cbe06f5 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -2,6 +2,7 @@ from collections.abc import Awaitable, Callable, Iterable from typing import Any, Generic, NamedTuple, Optional, Union, cast +from pydantic import BaseModel from typing_extensions import Unpack from workflowai.core._common_types import BaseRunParams, OutputValidator @@ -28,7 +29,6 @@ from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference from workflowai.core.utils._tools import tool_schema -from workflowai.core.utils._vars import BM class Agent(Generic[AgentInput, AgentOutput]): @@ -59,8 +59,9 @@ def build_tools(cls, tools: Iterable[Callable[..., Any]]) -> dict[str, tuple[Too def api(self) -> APIClient: return self._api() - class _PreparedRun(NamedTuple, Generic[BM]): - request: BM + class _PreparedRun(NamedTuple): + # would be nice to use a generic here, but python 3.9 does not support generic NamedTuple + request: BaseModel route: str should_retry: Callable[[], bool] wait_for_exception: Callable[[WorkflowAIError], Awaitable[None]] From 8bb0da5f8b66da9bd76d781bb478d8c8f64472e7 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Wed, 29 Jan 2025 23:39:59 -0500 Subject: [PATCH 7/7] chore: bump version dev3 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1cce866..77ac7c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev2" +version = "0.6.0.dev3" description = "" authors = ["Guillaume Aquilina "] readme = "README.md"