diff --git a/examples/city_to_capital_task.py b/examples/city_to_capital_task.py index e0045dd..e095e38 100644 --- a/examples/city_to_capital_task.py +++ b/examples/city_to_capital_task.py @@ -28,9 +28,9 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTa def main(city: str) -> None: async def _inner() -> None: task_input = CityToCapitalTaskInput(city=city) - task_output = await city_to_capital(task_input) + output = await city_to_capital(task_input) - rprint(task_output) + rprint(output) aiorun(_inner()) diff --git a/pyproject.toml b/pyproject.toml index ffd8e82..1cce866 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev1" +version = "0.6.0.dev2" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/tests/e2e/no_schema_test.py b/tests/e2e/no_schema_test.py index 6449c86..dc09a62 100644 --- a/tests/e2e/no_schema_test.py +++ b/tests/e2e/no_schema_test.py @@ -14,7 +14,7 @@ class SummarizeTaskOutput(BaseModel): @workflowai.agent(id="summarize", model="gemini-1.5-flash-latest") -async def summarize(task_input: SummarizeTaskInput) -> SummarizeTaskOutput: ... +async def summarize(_: SummarizeTaskInput) -> SummarizeTaskOutput: ... async def test_summarize(): diff --git a/tests/e2e/run_test.py b/tests/e2e/run_test.py index 7c098d6..4dc8a5a 100644 --- a/tests/e2e/run_test.py +++ b/tests/e2e/run_test.py @@ -27,7 +27,7 @@ class ExtractProductReviewSentimentTaskOutput(BaseModel): @workflowai.agent(id="extract-product-review-sentiment", schema_id=1) def extract_product_review_sentiment( - task_input: ExtractProductReviewSentimentTaskInput, + _: ExtractProductReviewSentimentTaskInput, ) -> AsyncIterator[ExtractProductReviewSentimentTaskOutput]: ... @@ -52,7 +52,7 @@ async def test_run_task( ): task_input = ExtractProductReviewSentimentTaskInput(review_text="This product is amazing!") run = await extract_product_review_sentiment_agent.run(task_input=task_input, use_cache="never") - assert run.task_output.sentiment == Sentiment.POSITIVE + assert run.output.sentiment == Sentiment.POSITIVE async def test_stream_task( diff --git a/tests/integration/run_test.py b/tests/integration/run_test.py index 9e8bd05..6645225 100644 --- a/tests/integration/run_test.py +++ b/tests/integration/run_test.py @@ -73,9 +73,9 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTa _mock_response(httpx_mock) task_input = CityToCapitalTaskInput(city="Hello") - task_output = await city_to_capital(task_input) + output = await city_to_capital(task_input) - assert task_output.capital == "Tokyo" + assert output.capital == "Tokyo" _check_request(httpx_mock.get_request()) @@ -90,7 +90,7 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit with_run = await city_to_capital(task_input) assert with_run.id == "123" - assert with_run.task_output.capital == "Tokyo" + assert with_run.output.capital == "Tokyo" _check_request(httpx_mock.get_request()) @@ -105,7 +105,7 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit with_run = await city_to_capital(task_input) assert with_run.id == "123" - assert with_run.task_output.capital == "Tokyo" + assert with_run.output.capital == "Tokyo" _check_request(httpx_mock.get_request(), version="staging") diff --git a/workflowai/__init__.py b/workflowai/__init__.py index adf66ec..eac4b5a 100644 --- a/workflowai/__init__.py +++ b/workflowai/__init__.py @@ -3,13 +3,13 @@ from typing_extensions import deprecated -from workflowai.core.client._types import TaskDecorator +from workflowai.core.client._types import AgentDecorator from workflowai.core.client.client import WorkflowAI as WorkflowAI from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError from workflowai.core.domain.model import Model as Model from workflowai.core.domain.run import Run as Run -from workflowai.core.domain.task_version import TaskVersion as TaskVersion +from workflowai.core.domain.version import Version as Version from workflowai.core.domain.version_reference import ( VersionReference as VersionReference, ) @@ -67,7 +67,7 @@ def agent( schema_id: Optional[int] = None, version: Optional[VersionReference] = None, model: Optional[Model] = None, -) -> TaskDecorator: +) -> AgentDecorator: from workflowai.core.client._fn_utils import agent_wrapper return agent_wrapper( diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index 13accda..57aee97 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -1,4 +1,5 @@ import functools +import inspect from collections.abc import Callable from typing import ( Any, @@ -19,15 +20,15 @@ from workflowai.core.client._api import APIClient from workflowai.core.client._types import ( + AgentDecorator, FinalRunTemplate, RunParams, RunTemplate, - TaskDecorator, ) from workflowai.core.client.agent import Agent from workflowai.core.domain.model import Model from workflowai.core.domain.run import Run -from workflowai.core.domain.task import TaskInput, TaskOutput +from workflowai.core.domain.task import AgentInput, AgentOutput from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference @@ -66,17 +67,28 @@ def is_async_iterator(t: type[Any]) -> bool: return issubclass(ori, AsyncIterator) -def extract_fn_spec(fn: RunTemplate[TaskInput, TaskOutput]) -> RunFunctionSpec: +def _first_arg_name(fn: Callable[..., Any]) -> Optional[str]: + sig = inspect.signature(fn) + for param in sig.parameters.values(): + if param.kind == param.POSITIONAL_OR_KEYWORD: + return param.name + return None + + +def extract_fn_spec(fn: RunTemplate[AgentInput, AgentOutput]) -> RunFunctionSpec: + first_arg_name = _first_arg_name(fn) + if not first_arg_name: + raise ValueError("Function must have a first positional argument") hints = get_type_hints(fn) if "return" not in hints: raise ValueError("Function must have a return type hint") - if "task_input" not in hints: - raise ValueError("Function must have a task_input parameter") + if first_arg_name not in hints: + raise ValueError("Function must have a first positional parameter") return_type_hint = hints["return"] - input_cls = hints["task_input"] + input_cls = hints[first_arg_name] if not issubclass(input_cls, BaseModel): - raise ValueError("task_input must be a subclass of BaseModel") + raise ValueError("First positional parameter must be a subclass of BaseModel") output_cls = None @@ -90,25 +102,25 @@ def extract_fn_spec(fn: RunTemplate[TaskInput, TaskOutput]) -> RunFunctionSpec: return RunFunctionSpec(stream, output_only, input_cls, output_cls) -class _RunnableAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): - async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): - return await self.run(task_input, **kwargs) +class _RunnableAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): + async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + return await self.run(input, **kwargs) -class _RunnableOutputOnlyAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): - async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): - return (await self.run(task_input, **kwargs)).task_output +class _RunnableOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): + async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + return (await self.run(input, **kwargs)).output -class _RunnableStreamAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): - def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): - return self.stream(task_input, **kwargs) +class _RunnableStreamAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): + def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + return self.stream(input, **kwargs) -class _RunnableStreamOutputOnlyAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): - async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): - async for chunk in self.stream(task_input, **kwargs): - yield chunk.task_output +class _RunnableStreamOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): + async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + async for chunk in self.stream(input, **kwargs): + yield chunk.output def wrap_run_template( @@ -117,12 +129,12 @@ def wrap_run_template( schema_id: Optional[int], version: Optional[VersionReference], model: Optional[Model], - fn: RunTemplate[TaskInput, TaskOutput], + fn: RunTemplate[AgentInput, AgentOutput], ) -> Union[ - _RunnableAgent[TaskInput, TaskOutput], - _RunnableOutputOnlyAgent[TaskInput, TaskOutput], - _RunnableStreamAgent[TaskInput, TaskOutput], - _RunnableStreamOutputOnlyAgent[TaskInput, TaskOutput], + _RunnableAgent[AgentInput, AgentOutput], + _RunnableOutputOnlyAgent[AgentInput, AgentOutput], + _RunnableStreamAgent[AgentInput, AgentOutput], + _RunnableStreamOutputOnlyAgent[AgentInput, AgentOutput], ]: stream, output_only, input_cls, output_cls = extract_fn_spec(fn) @@ -156,8 +168,8 @@ def agent_wrapper( agent_id: Optional[str] = None, version: Optional[VersionReference] = None, model: Optional[Model] = None, -) -> TaskDecorator: - def wrap(fn: RunTemplate[TaskInput, TaskOutput]) -> FinalRunTemplate[TaskInput, TaskOutput]: +) -> AgentDecorator: + def wrap(fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]: tid = agent_id or agent_id_from_fn_name(fn) return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn)) # pyright: ignore [reportReturnType] diff --git a/workflowai/core/client/_fn_utils_test.py b/workflowai/core/client/_fn_utils_test.py index a5a7534..c5cd37f 100644 --- a/workflowai/core/client/_fn_utils_test.py +++ b/workflowai/core/client/_fn_utils_test.py @@ -21,16 +21,16 @@ from workflowai.core.domain.run import Run -async def say_hello(task_input: HelloTaskInput) -> HelloTaskOutput: ... +async def say_hello(_: HelloTaskInput) -> HelloTaskOutput: ... -async def say_hello_run(task_input: HelloTaskInput) -> Run[HelloTaskOutput]: ... +async def say_hello_run(bla: HelloTaskInput) -> Run[HelloTaskOutput]: ... -def stream_hello(task_input: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ... +def stream_hello(_: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ... -def stream_hello_run(task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... +def stream_hello_run(_: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... class TestGetGenericArgs: @@ -76,7 +76,7 @@ async def test_fn_run(self, mock_api_client: Mock): run = await wrapped(HelloTaskInput(name="World")) assert isinstance(run, Run) assert run.id == "1" - assert run.task_output == HelloTaskOutput(message="Hello, World!") + assert run.output == HelloTaskOutput(message="Hello, World!") def fn_stream(self, task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... @@ -89,7 +89,7 @@ async def test_fn_stream(self, mock_api_client: Mock): assert len(chunks) == 1 assert isinstance(chunks[0], Run) assert chunks[0].id == "1" - assert chunks[0].task_output == HelloTaskOutput(message="Hello, World!") + assert chunks[0].output == HelloTaskOutput(message="Hello, World!") async def fn_run_output_only(self, task_input: HelloTaskInput) -> HelloTaskOutput: ... diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index d179163..37f44e9 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -6,8 +6,8 @@ from workflowai.core.client._types import OutputValidator from workflowai.core.domain.cache_usage import CacheUsage from workflowai.core.domain.run import Run -from workflowai.core.domain.task import TaskOutput -from workflowai.core.domain.task_version import TaskVersion +from workflowai.core.domain.task import AgentOutput +from workflowai.core.domain.version import Version as DVersion from workflowai.core.domain.version_properties import VersionProperties as DVersionProperties @@ -47,14 +47,14 @@ class RunResponse(BaseModel): duration_seconds: Optional[float] = None cost_usd: Optional[float] = None - def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidator[TaskOutput]) -> Run[TaskOutput]: + def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidator[AgentOutput]) -> Run[AgentOutput]: return Run( id=self.id, - task_id=task_id, - task_schema_id=task_schema_id, - task_output=validator(self.task_output), + agent_id=task_id, + schema_id=task_schema_id, + output=validator(self.task_output), version=self.version - and TaskVersion( + and DVersion( properties=DVersionProperties.model_construct( None, **self.version.properties, diff --git a/workflowai/core/client/_models_test.py b/workflowai/core/client/_models_test.py index 4db04b5..46b8c72 100644 --- a/workflowai/core/client/_models_test.py +++ b/workflowai/core/client/_models_test.py @@ -42,10 +42,10 @@ def test_no_version_not_optional(self): parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) assert isinstance(parsed, Run) - assert parsed.task_output.a == 1 + assert parsed.output.a == 1 # b is not defined with pytest.raises(AttributeError): - assert parsed.task_output.b + assert parsed.output.b def test_no_version_optional(self): chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}') @@ -53,8 +53,8 @@ def test_no_version_optional(self): parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutputOpt)) assert isinstance(parsed, Run) - assert parsed.task_output.a == 1 - assert parsed.task_output.b is None + assert parsed.output.a == 1 + assert parsed.output.b is None def test_with_version(self): chunk = RunResponse.model_validate_json( @@ -64,8 +64,8 @@ def test_with_version(self): parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) assert isinstance(parsed, Run) - assert parsed.task_output.a == 1 - assert parsed.task_output.b == "test" + assert parsed.output.a == 1 + assert parsed.output.b == "test" assert parsed.cost_usd == 0.1 assert parsed.duration_seconds == 1 diff --git a/workflowai/core/client/_types.py b/workflowai/core/client/_types.py index 0bb07a8..7dd3bdc 100644 --- a/workflowai/core/client/_types.py +++ b/workflowai/core/client/_types.py @@ -15,52 +15,46 @@ from workflowai.core.domain.cache_usage import CacheUsage from workflowai.core.domain.run import Run -from workflowai.core.domain.task import TaskInput, TaskOutput +from workflowai.core.domain.task import AgentInput, AgentOutput from workflowai.core.domain.version_reference import VersionReference -TaskInputContra = TypeVar("TaskInputContra", bound=BaseModel, contravariant=True) -TaskOutputCov = TypeVar("TaskOutputCov", bound=BaseModel, covariant=True) +AgentInputContra = TypeVar("AgentInputContra", bound=BaseModel, contravariant=True) +AgentOutputCov = TypeVar("AgentOutputCov", bound=BaseModel, covariant=True) -OutputValidator = Callable[[dict[str, Any]], TaskOutput] +OutputValidator = Callable[[dict[str, Any]], AgentOutput] -class RunParams(TypedDict, Generic[TaskOutput]): +class RunParams(TypedDict, Generic[AgentOutput]): version: NotRequired[Optional[VersionReference]] use_cache: NotRequired[CacheUsage] metadata: NotRequired[Optional[dict[str, Any]]] labels: NotRequired[Optional[set[str]]] max_retry_delay: NotRequired[float] max_retry_count: NotRequired[float] - validator: NotRequired[OutputValidator[TaskOutput]] + validator: NotRequired[OutputValidator[AgentOutput]] -class RunFn(Protocol, Generic[TaskInputContra, TaskOutput]): - async def __call__(self, task_input: TaskInputContra) -> Run[TaskOutput]: ... +class RunFn(Protocol, Generic[AgentInputContra, AgentOutput]): + async def __call__(self, _: AgentInputContra, /) -> Run[AgentOutput]: ... -class RunFnOutputOnly(Protocol, Generic[TaskInputContra, TaskOutputCov]): - async def __call__(self, task_input: TaskInputContra) -> TaskOutputCov: ... +class RunFnOutputOnly(Protocol, Generic[AgentInputContra, AgentOutputCov]): + async def __call__(self, _: AgentInputContra, /) -> AgentOutputCov: ... -class StreamRunFn(Protocol, Generic[TaskInputContra, TaskOutput]): - def __call__( - self, - task_input: TaskInputContra, - ) -> AsyncIterator[Run[TaskOutput]]: ... +class StreamRunFn(Protocol, Generic[AgentInputContra, AgentOutput]): + def __call__(self, _: AgentInputContra, /) -> AsyncIterator[Run[AgentOutput]]: ... -class StreamRunFnOutputOnly(Protocol, Generic[TaskInputContra, TaskOutputCov]): - def __call__( - self, - task_input: TaskInputContra, - ) -> AsyncIterator[TaskOutputCov]: ... +class StreamRunFnOutputOnly(Protocol, Generic[AgentInputContra, AgentOutputCov]): + def __call__(self, _: AgentInputContra, /) -> AsyncIterator[AgentOutputCov]: ... RunTemplate = Union[ - RunFn[TaskInput, TaskOutput], - RunFnOutputOnly[TaskInput, TaskOutput], - StreamRunFn[TaskInput, TaskOutput], - StreamRunFnOutputOnly[TaskInput, TaskOutput], + RunFn[AgentInput, AgentOutput], + RunFnOutputOnly[AgentInput, AgentOutput], + StreamRunFn[AgentInput, AgentOutput], + StreamRunFnOutputOnly[AgentInput, AgentOutput], ] @@ -75,60 +69,67 @@ class _BaseProtocol(Protocol): __code__: Any -class FinalRunFn(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutput]): +class FinalRunFn(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): async def __call__( self, - task_input: TaskInputContra, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> Run[TaskOutput]: ... + _: AgentInputContra, + /, + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> Run[AgentOutput]: ... -class FinalRunFnOutputOnly(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutput]): +class FinalRunFnOutputOnly(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): async def __call__( self, - task_input: TaskInputContra, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> TaskOutput: ... + _: AgentInputContra, + /, + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> AgentOutput: ... -class FinalStreamRunFn(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutput]): +class FinalStreamRunFn(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): def __call__( self, - task_input: TaskInputContra, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> AsyncIterator[Run[TaskOutput]]: ... + _: AgentInputContra, + /, + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> AsyncIterator[Run[AgentOutput]]: ... -class FinalStreamRunFnOutputOnly(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutputCov]): +class FinalStreamRunFnOutputOnly(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutputCov]): def __call__( self, - task_input: TaskInputContra, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> AsyncIterator[TaskOutputCov]: ... + _: AgentInputContra, + /, + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> AsyncIterator[AgentOutputCov]: ... FinalRunTemplate = Union[ - FinalRunFn[TaskInput, TaskOutput], - FinalRunFnOutputOnly[TaskInput, TaskOutput], - FinalStreamRunFn[TaskInput, TaskOutput], - FinalStreamRunFnOutputOnly[TaskInput, TaskOutput], + FinalRunFn[AgentInput, AgentOutput], + FinalRunFnOutputOnly[AgentInput, AgentOutput], + FinalStreamRunFn[AgentInput, AgentOutput], + FinalStreamRunFnOutputOnly[AgentInput, AgentOutput], ] -class TaskDecorator(Protocol): +class AgentDecorator(Protocol): @overload - def __call__(self, fn: RunFn[TaskInput, TaskOutput]) -> FinalRunFn[TaskInput, TaskOutput]: ... + def __call__(self, fn: RunFn[AgentInput, AgentOutput]) -> FinalRunFn[AgentInput, AgentOutput]: ... @overload - def __call__(self, fn: RunFnOutputOnly[TaskInput, TaskOutput]) -> FinalRunFnOutputOnly[TaskInput, TaskOutput]: ... + def __call__( + self, + fn: RunFnOutputOnly[AgentInput, AgentOutput], + ) -> FinalRunFnOutputOnly[AgentInput, AgentOutput]: ... @overload - def __call__(self, fn: StreamRunFn[TaskInput, TaskOutput]) -> FinalStreamRunFn[TaskInput, TaskOutput]: ... + def __call__(self, fn: StreamRunFn[AgentInput, AgentOutput]) -> FinalStreamRunFn[AgentInput, AgentOutput]: ... @overload def __call__( self, - fn: StreamRunFnOutputOnly[TaskInput, TaskOutput], - ) -> FinalStreamRunFnOutputOnly[TaskInput, TaskOutput]: ... + fn: StreamRunFnOutputOnly[AgentInput, AgentOutput], + ) -> FinalStreamRunFnOutputOnly[AgentInput, AgentOutput]: ... - def __call__(self, fn: RunTemplate[TaskInput, TaskOutput]) -> FinalRunTemplate[TaskInput, TaskOutput]: ... + def __call__(self, fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]: ... diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index 44dc6d2..3d3d181 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -9,7 +9,7 @@ from workflowai.core.client._types import OutputValidator from workflowai.core.domain.errors import BaseError, WorkflowAIError -from workflowai.core.domain.task import TaskOutput +from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.version_reference import VersionReference from workflowai.core.logger import logger @@ -85,7 +85,7 @@ async def _wait_for_exception(e: WorkflowAIError): return _should_retry, _wait_for_exception -def tolerant_validator(m: type[TaskOutput]) -> OutputValidator[TaskOutput]: +def tolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: return lambda payload: m.model_construct(None, **payload) diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 4386526..9793163 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -9,17 +9,17 @@ from workflowai.core.client._utils import build_retryable_wait, global_default_version_reference, tolerant_validator from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.run import Run -from workflowai.core.domain.task import TaskInput, TaskOutput +from workflowai.core.domain.task import AgentInput, AgentOutput from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference -class Agent(Generic[TaskInput, TaskOutput]): +class Agent(Generic[AgentInput, AgentOutput]): def __init__( self, agent_id: str, - input_cls: type[TaskInput], - output_cls: type[TaskOutput], + input_cls: type[AgentInput], + output_cls: type[AgentOutput], api: Union[APIClient, Callable[[], APIClient]], schema_id: Optional[int] = None, version: Optional[VersionReference] = None, @@ -55,7 +55,7 @@ def _sanitize_version(self, version: Optional[VersionReference]) -> Union[str, i dumped["model"] = workflowai.DEFAULT_MODEL return dumped - async def _prepare_run(self, task_input: TaskInput, stream: bool, **kwargs: Unpack[RunParams[TaskOutput]]): + async def _prepare_run(self, task_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]): schema_id = self.schema_id if not schema_id: schema_id = await self.register() @@ -94,13 +94,13 @@ async def register(self): async def run( self, - task_input: TaskInput, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> Run[TaskOutput]: + task_input: AgentInput, + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> Run[AgentOutput]: """Run the agent Args: - task_input (TaskInput): the input to the task + task_input (AgentInput): the input to the task version (Optional[TaskVersionReference], optional): the version of the task to run. If not provided, the version defined in the task is used. Defaults to None. use_cache (CacheUsage, optional): how to use the cache. Defaults to "auto". @@ -118,7 +118,7 @@ async def run( max_retry_count (int, optional): The maximum number of retry attempts. Defaults to 1. Returns: - Union[TaskRun[TaskInput, TaskOutput], AsyncIterator[TaskOutput]]: the task run object + Union[TaskRun[AgentInput, AgentOutput], AsyncIterator[AgentOutput]]: the task run object or an async iterator of output objects """ prepared_run = await self._prepare_run(task_input, stream=False, **kwargs) @@ -137,13 +137,13 @@ async def run( async def stream( self, - task_input: TaskInput, - **kwargs: Unpack[RunParams[TaskOutput]], + task_input: AgentInput, + **kwargs: Unpack[RunParams[AgentOutput]], ): """Stream the output of the agent Args: - task_input (TaskInput): the input to the task + task_input (AgentInput): the input to the task version (Optional[TaskVersionReference], optional): the version of the task to run. If not provided, the version defined in the task is used. Defaults to None. use_cache (CacheUsage, optional): how to use the cache. Defaults to "auto". @@ -161,7 +161,7 @@ async def stream( max_retry_count (int, optional): The maximum number of retry attempts. Defaults to 1. Returns: - Union[TaskRun[TaskInput, TaskOutput], AsyncIterator[TaskOutput]]: the task run object + Union[TaskRun[AgentInput, AgentOutput], AsyncIterator[AgentOutput]]: the task run object or an async iterator of output objects """ prepared_run = await self._prepare_run(task_input, stream=True, **kwargs) diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 228ba3a..ec4d881 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -58,8 +58,8 @@ async def test_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, task_run = await agent.run(task_input=HelloTaskInput(name="Alice")) assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" - assert task_run.task_id == "123" - assert task_run.task_schema_id == 1 + assert task_run.agent_id == "123" + assert task_run.schema_id == 1 reqs = httpx_mock.get_requests() assert len(reqs) == 1 @@ -85,7 +85,7 @@ async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, chunks = [chunk async for chunk in agent.stream(task_input=HelloTaskInput(name="Alice"))] - outputs = [chunk.task_output for chunk in chunks] + outputs = [chunk.output for chunk in chunks] assert outputs == [ HelloTaskOutput(message=""), HelloTaskOutput(message="hel"), @@ -119,14 +119,14 @@ async def test_stream_not_optional( chunks = [chunk async for chunk in agent_not_optional.stream(task_input=HelloTaskInput(name="Alice"))] - messages = [chunk.task_output.message for chunk in chunks] + messages = [chunk.output.message for chunk in chunks] assert messages == ["", "hel", "hello", "hello"] for chunk in chunks[:-1]: with pytest.raises(AttributeError): # Since the field is not optional, it will raise an attribute error - assert chunk.task_output.another_field - assert chunks[-1].task_output.another_field == "test" + assert chunk.output.another_field + assert chunks[-1].output.another_field == "test" last_message = chunks[-1] assert isinstance(last_message, Run) diff --git a/workflowai/core/client/client_test.py b/workflowai/core/client/client_test.py index f268340..9f81c16 100644 --- a/workflowai/core/client/client_test.py +++ b/workflowai/core/client/client_test.py @@ -39,9 +39,9 @@ async def test_run_output_only(self, workflowai: WorkflowAI, mock_run_fn: Mock): async def fn(task_input: HelloTaskInput) -> HelloTaskOutput: ... mock_run_fn.return_value = Run( - task_output=HelloTaskOutput(message="hello"), - task_id="123", - task_schema_id=1, + output=HelloTaskOutput(message="hello"), + agent_id="123", + schema_id=1, ) output = await fn(HelloTaskInput(name="Alice")) @@ -54,16 +54,16 @@ async def fn(task_input: HelloTaskInput) -> Run[HelloTaskOutput]: ... mock_run_fn.return_value = Run( id="1", - task_output=HelloTaskOutput(message="hello"), - task_id="123", - task_schema_id=1, + output=HelloTaskOutput(message="hello"), + agent_id="123", + schema_id=1, ) - output = await fn(HelloTaskInput(name="Alice")) + run = await fn(HelloTaskInput(name="Alice")) - assert output.id == "1" - assert output.task_output == HelloTaskOutput(message="hello") - assert isinstance(output, Run) + assert run.id == "1" + assert run.output == HelloTaskOutput(message="hello") + assert isinstance(run, Run) async def test_stream(self, workflowai: WorkflowAI, httpx_mock: HTTPXMock): # We avoid mocking the run fn directly here, python does weird things with @@ -86,7 +86,7 @@ def fn(task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... chunks = [chunk async for chunk in fn(HelloTaskInput(name="Alice"))] def _run(output: HelloTaskOutput, **kwargs: Any) -> Run[HelloTaskOutput]: - return Run(id="1", task_id="123", task_schema_id=1, task_output=output, **kwargs) + return Run(id="1", agent_id="123", schema_id=1, output=output, **kwargs) assert chunks == [ _run(HelloTaskOutput(message="")), diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 9314f9a..08ce2fb 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -3,32 +3,32 @@ from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType] -from workflowai.core.domain.task import TaskOutput -from workflowai.core.domain.task_version import TaskVersion +from workflowai.core.domain.task import AgentOutput +from workflowai.core.domain.version import Version -class Run(BaseModel, Generic[TaskOutput]): +class Run(BaseModel, Generic[AgentOutput]): """ - A task run is an instance of a task with a specific input and output. + A run is an instance of a agent with a specific input and output. - This class represent a task run that already has been recorded and possibly + This class represent a run that already has been recorded and possibly been evaluated """ id: str = Field( default_factory=lambda: str(uuid.uuid4()), - description="The unique identifier of the task run. This is a UUIDv7.", + description="The unique identifier of the run. This is a UUIDv7.", ) - task_id: str - task_schema_id: int - task_output: TaskOutput + agent_id: str + schema_id: int + output: AgentOutput duration_seconds: Optional[float] = None cost_usd: Optional[float] = None - version: Optional[TaskVersion] = Field( + version: Optional[Version] = Field( default=None, - description="The version of the task that was run. Only provided if the version differs from the version" + description="The version of the agent that was run. Only provided if the version differs from the version" " specified in the request, for example in case of a model fallback", ) diff --git a/workflowai/core/domain/task.py b/workflowai/core/domain/task.py index 285cba9..ec792bd 100644 --- a/workflowai/core/domain/task.py +++ b/workflowai/core/domain/task.py @@ -2,5 +2,5 @@ from pydantic import BaseModel -TaskInput = TypeVar("TaskInput", bound=BaseModel) -TaskOutput = TypeVar("TaskOutput", bound=BaseModel) +AgentInput = TypeVar("AgentInput", bound=BaseModel) +AgentOutput = TypeVar("AgentOutput", bound=BaseModel) diff --git a/workflowai/core/domain/task_version.py b/workflowai/core/domain/version.py similarity index 90% rename from workflowai/core/domain/task_version.py rename to workflowai/core/domain/version.py index a930923..5002d1a 100644 --- a/workflowai/core/domain/task_version.py +++ b/workflowai/core/domain/version.py @@ -3,7 +3,7 @@ from workflowai.core.domain.version_properties import VersionProperties -class TaskVersion(BaseModel): +class Version(BaseModel): properties: VersionProperties = Field( default_factory=VersionProperties, description="The properties used for executing the run.",