From 39b47192933d524dda34f0727fe7d8314db1b7b6 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Thu, 23 Jan 2025 15:19:39 -0500 Subject: [PATCH 1/4] chore: rename task output -> output --- examples/city_to_capital_task.py | 4 +- tests/e2e/run_test.py | 2 +- tests/integration/run_test.py | 8 +-- workflowai/core/client/_fn_utils.py | 36 ++++++------- workflowai/core/client/_fn_utils_test.py | 4 +- workflowai/core/client/_models.py | 10 ++-- workflowai/core/client/_models_test.py | 12 ++--- workflowai/core/client/_types.py | 67 +++++++++++++----------- workflowai/core/client/_utils.py | 4 +- workflowai/core/client/agent.py | 28 +++++----- workflowai/core/client/agent_test.py | 12 ++--- workflowai/core/client/client_test.py | 22 ++++---- workflowai/core/domain/run.py | 18 +++---- workflowai/core/domain/task.py | 4 +- 14 files changed, 117 insertions(+), 114 deletions(-) diff --git a/examples/city_to_capital_task.py b/examples/city_to_capital_task.py index e0045dd..e095e38 100644 --- a/examples/city_to_capital_task.py +++ b/examples/city_to_capital_task.py @@ -28,9 +28,9 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTa def main(city: str) -> None: async def _inner() -> None: task_input = CityToCapitalTaskInput(city=city) - task_output = await city_to_capital(task_input) + output = await city_to_capital(task_input) - rprint(task_output) + rprint(output) aiorun(_inner()) diff --git a/tests/e2e/run_test.py b/tests/e2e/run_test.py index 7c098d6..fbb2cdd 100644 --- a/tests/e2e/run_test.py +++ b/tests/e2e/run_test.py @@ -52,7 +52,7 @@ async def test_run_task( ): task_input = ExtractProductReviewSentimentTaskInput(review_text="This product is amazing!") run = await extract_product_review_sentiment_agent.run(task_input=task_input, use_cache="never") - assert run.task_output.sentiment == Sentiment.POSITIVE + assert run.output.sentiment == Sentiment.POSITIVE async def test_stream_task( diff --git a/tests/integration/run_test.py b/tests/integration/run_test.py index 9e8bd05..6645225 100644 --- a/tests/integration/run_test.py +++ b/tests/integration/run_test.py @@ -73,9 +73,9 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTa _mock_response(httpx_mock) task_input = CityToCapitalTaskInput(city="Hello") - task_output = await city_to_capital(task_input) + output = await city_to_capital(task_input) - assert task_output.capital == "Tokyo" + assert output.capital == "Tokyo" _check_request(httpx_mock.get_request()) @@ -90,7 +90,7 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit with_run = await city_to_capital(task_input) assert with_run.id == "123" - assert with_run.task_output.capital == "Tokyo" + assert with_run.output.capital == "Tokyo" _check_request(httpx_mock.get_request()) @@ -105,7 +105,7 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit with_run = await city_to_capital(task_input) assert with_run.id == "123" - assert with_run.task_output.capital == "Tokyo" + assert with_run.output.capital == "Tokyo" _check_request(httpx_mock.get_request(), version="staging") diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index 13accda..b340ab8 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -27,7 +27,7 @@ from workflowai.core.client.agent import Agent from workflowai.core.domain.model import Model from workflowai.core.domain.run import Run -from workflowai.core.domain.task import TaskInput, TaskOutput +from workflowai.core.domain.task import AgentInput, AgentOutput from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference @@ -66,7 +66,7 @@ def is_async_iterator(t: type[Any]) -> bool: return issubclass(ori, AsyncIterator) -def extract_fn_spec(fn: RunTemplate[TaskInput, TaskOutput]) -> RunFunctionSpec: +def extract_fn_spec(fn: RunTemplate[AgentInput, AgentOutput]) -> RunFunctionSpec: hints = get_type_hints(fn) if "return" not in hints: raise ValueError("Function must have a return type hint") @@ -90,25 +90,25 @@ def extract_fn_spec(fn: RunTemplate[TaskInput, TaskOutput]) -> RunFunctionSpec: return RunFunctionSpec(stream, output_only, input_cls, output_cls) -class _RunnableAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): - async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): +class _RunnableAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): + async def __call__(self, task_input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): return await self.run(task_input, **kwargs) -class _RunnableOutputOnlyAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): - async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): - return (await self.run(task_input, **kwargs)).task_output +class _RunnableOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): + async def __call__(self, task_input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): + return (await self.run(task_input, **kwargs)).output -class _RunnableStreamAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): - def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): +class _RunnableStreamAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): + def __call__(self, task_input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): return self.stream(task_input, **kwargs) -class _RunnableStreamOutputOnlyAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]): - async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]): +class _RunnableStreamOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): + async def __call__(self, task_input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): async for chunk in self.stream(task_input, **kwargs): - yield chunk.task_output + yield chunk.output def wrap_run_template( @@ -117,12 +117,12 @@ def wrap_run_template( schema_id: Optional[int], version: Optional[VersionReference], model: Optional[Model], - fn: RunTemplate[TaskInput, TaskOutput], + fn: RunTemplate[AgentInput, AgentOutput], ) -> Union[ - _RunnableAgent[TaskInput, TaskOutput], - _RunnableOutputOnlyAgent[TaskInput, TaskOutput], - _RunnableStreamAgent[TaskInput, TaskOutput], - _RunnableStreamOutputOnlyAgent[TaskInput, TaskOutput], + _RunnableAgent[AgentInput, AgentOutput], + _RunnableOutputOnlyAgent[AgentInput, AgentOutput], + _RunnableStreamAgent[AgentInput, AgentOutput], + _RunnableStreamOutputOnlyAgent[AgentInput, AgentOutput], ]: stream, output_only, input_cls, output_cls = extract_fn_spec(fn) @@ -157,7 +157,7 @@ def agent_wrapper( version: Optional[VersionReference] = None, model: Optional[Model] = None, ) -> TaskDecorator: - def wrap(fn: RunTemplate[TaskInput, TaskOutput]) -> FinalRunTemplate[TaskInput, TaskOutput]: + def wrap(fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]: tid = agent_id or agent_id_from_fn_name(fn) return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn)) # pyright: ignore [reportReturnType] diff --git a/workflowai/core/client/_fn_utils_test.py b/workflowai/core/client/_fn_utils_test.py index a5a7534..0d955fd 100644 --- a/workflowai/core/client/_fn_utils_test.py +++ b/workflowai/core/client/_fn_utils_test.py @@ -76,7 +76,7 @@ async def test_fn_run(self, mock_api_client: Mock): run = await wrapped(HelloTaskInput(name="World")) assert isinstance(run, Run) assert run.id == "1" - assert run.task_output == HelloTaskOutput(message="Hello, World!") + assert run.output == HelloTaskOutput(message="Hello, World!") def fn_stream(self, task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... @@ -89,7 +89,7 @@ async def test_fn_stream(self, mock_api_client: Mock): assert len(chunks) == 1 assert isinstance(chunks[0], Run) assert chunks[0].id == "1" - assert chunks[0].task_output == HelloTaskOutput(message="Hello, World!") + assert chunks[0].output == HelloTaskOutput(message="Hello, World!") async def fn_run_output_only(self, task_input: HelloTaskInput) -> HelloTaskOutput: ... diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index d179163..fac6ca0 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -6,7 +6,7 @@ from workflowai.core.client._types import OutputValidator from workflowai.core.domain.cache_usage import CacheUsage from workflowai.core.domain.run import Run -from workflowai.core.domain.task import TaskOutput +from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.task_version import TaskVersion from workflowai.core.domain.version_properties import VersionProperties as DVersionProperties @@ -47,12 +47,12 @@ class RunResponse(BaseModel): duration_seconds: Optional[float] = None cost_usd: Optional[float] = None - def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidator[TaskOutput]) -> Run[TaskOutput]: + def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidator[AgentOutput]) -> Run[AgentOutput]: return Run( id=self.id, - task_id=task_id, - task_schema_id=task_schema_id, - task_output=validator(self.task_output), + agent_id=task_id, + schema_id=task_schema_id, + output=validator(self.task_output), version=self.version and TaskVersion( properties=DVersionProperties.model_construct( diff --git a/workflowai/core/client/_models_test.py b/workflowai/core/client/_models_test.py index 4db04b5..46b8c72 100644 --- a/workflowai/core/client/_models_test.py +++ b/workflowai/core/client/_models_test.py @@ -42,10 +42,10 @@ def test_no_version_not_optional(self): parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) assert isinstance(parsed, Run) - assert parsed.task_output.a == 1 + assert parsed.output.a == 1 # b is not defined with pytest.raises(AttributeError): - assert parsed.task_output.b + assert parsed.output.b def test_no_version_optional(self): chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}') @@ -53,8 +53,8 @@ def test_no_version_optional(self): parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutputOpt)) assert isinstance(parsed, Run) - assert parsed.task_output.a == 1 - assert parsed.task_output.b is None + assert parsed.output.a == 1 + assert parsed.output.b is None def test_with_version(self): chunk = RunResponse.model_validate_json( @@ -64,8 +64,8 @@ def test_with_version(self): parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput)) assert isinstance(parsed, Run) - assert parsed.task_output.a == 1 - assert parsed.task_output.b == "test" + assert parsed.output.a == 1 + assert parsed.output.b == "test" assert parsed.cost_usd == 0.1 assert parsed.duration_seconds == 1 diff --git a/workflowai/core/client/_types.py b/workflowai/core/client/_types.py index 0bb07a8..fddef7e 100644 --- a/workflowai/core/client/_types.py +++ b/workflowai/core/client/_types.py @@ -15,38 +15,38 @@ from workflowai.core.domain.cache_usage import CacheUsage from workflowai.core.domain.run import Run -from workflowai.core.domain.task import TaskInput, TaskOutput +from workflowai.core.domain.task import AgentInput, AgentOutput from workflowai.core.domain.version_reference import VersionReference TaskInputContra = TypeVar("TaskInputContra", bound=BaseModel, contravariant=True) TaskOutputCov = TypeVar("TaskOutputCov", bound=BaseModel, covariant=True) -OutputValidator = Callable[[dict[str, Any]], TaskOutput] +OutputValidator = Callable[[dict[str, Any]], AgentOutput] -class RunParams(TypedDict, Generic[TaskOutput]): +class RunParams(TypedDict, Generic[AgentOutput]): version: NotRequired[Optional[VersionReference]] use_cache: NotRequired[CacheUsage] metadata: NotRequired[Optional[dict[str, Any]]] labels: NotRequired[Optional[set[str]]] max_retry_delay: NotRequired[float] max_retry_count: NotRequired[float] - validator: NotRequired[OutputValidator[TaskOutput]] + validator: NotRequired[OutputValidator[AgentOutput]] -class RunFn(Protocol, Generic[TaskInputContra, TaskOutput]): - async def __call__(self, task_input: TaskInputContra) -> Run[TaskOutput]: ... +class RunFn(Protocol, Generic[TaskInputContra, AgentOutput]): + async def __call__(self, task_input: TaskInputContra) -> Run[AgentOutput]: ... class RunFnOutputOnly(Protocol, Generic[TaskInputContra, TaskOutputCov]): async def __call__(self, task_input: TaskInputContra) -> TaskOutputCov: ... -class StreamRunFn(Protocol, Generic[TaskInputContra, TaskOutput]): +class StreamRunFn(Protocol, Generic[TaskInputContra, AgentOutput]): def __call__( self, task_input: TaskInputContra, - ) -> AsyncIterator[Run[TaskOutput]]: ... + ) -> AsyncIterator[Run[AgentOutput]]: ... class StreamRunFnOutputOnly(Protocol, Generic[TaskInputContra, TaskOutputCov]): @@ -57,10 +57,10 @@ def __call__( RunTemplate = Union[ - RunFn[TaskInput, TaskOutput], - RunFnOutputOnly[TaskInput, TaskOutput], - StreamRunFn[TaskInput, TaskOutput], - StreamRunFnOutputOnly[TaskInput, TaskOutput], + RunFn[AgentInput, AgentOutput], + RunFnOutputOnly[AgentInput, AgentOutput], + StreamRunFn[AgentInput, AgentOutput], + StreamRunFnOutputOnly[AgentInput, AgentOutput], ] @@ -75,60 +75,63 @@ class _BaseProtocol(Protocol): __code__: Any -class FinalRunFn(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutput]): +class FinalRunFn(_BaseProtocol, Protocol, Generic[TaskInputContra, AgentOutput]): async def __call__( self, task_input: TaskInputContra, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> Run[TaskOutput]: ... + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> Run[AgentOutput]: ... -class FinalRunFnOutputOnly(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutput]): +class FinalRunFnOutputOnly(_BaseProtocol, Protocol, Generic[TaskInputContra, AgentOutput]): async def __call__( self, task_input: TaskInputContra, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> TaskOutput: ... + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> AgentOutput: ... -class FinalStreamRunFn(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutput]): +class FinalStreamRunFn(_BaseProtocol, Protocol, Generic[TaskInputContra, AgentOutput]): def __call__( self, task_input: TaskInputContra, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> AsyncIterator[Run[TaskOutput]]: ... + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> AsyncIterator[Run[AgentOutput]]: ... class FinalStreamRunFnOutputOnly(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutputCov]): def __call__( self, task_input: TaskInputContra, - **kwargs: Unpack[RunParams[TaskOutput]], + **kwargs: Unpack[RunParams[AgentOutput]], ) -> AsyncIterator[TaskOutputCov]: ... FinalRunTemplate = Union[ - FinalRunFn[TaskInput, TaskOutput], - FinalRunFnOutputOnly[TaskInput, TaskOutput], - FinalStreamRunFn[TaskInput, TaskOutput], - FinalStreamRunFnOutputOnly[TaskInput, TaskOutput], + FinalRunFn[AgentInput, AgentOutput], + FinalRunFnOutputOnly[AgentInput, AgentOutput], + FinalStreamRunFn[AgentInput, AgentOutput], + FinalStreamRunFnOutputOnly[AgentInput, AgentOutput], ] class TaskDecorator(Protocol): @overload - def __call__(self, fn: RunFn[TaskInput, TaskOutput]) -> FinalRunFn[TaskInput, TaskOutput]: ... + def __call__(self, fn: RunFn[AgentInput, AgentOutput]) -> FinalRunFn[AgentInput, AgentOutput]: ... @overload - def __call__(self, fn: RunFnOutputOnly[TaskInput, TaskOutput]) -> FinalRunFnOutputOnly[TaskInput, TaskOutput]: ... + def __call__( + self, + fn: RunFnOutputOnly[AgentInput, AgentOutput], + ) -> FinalRunFnOutputOnly[AgentInput, AgentOutput]: ... @overload - def __call__(self, fn: StreamRunFn[TaskInput, TaskOutput]) -> FinalStreamRunFn[TaskInput, TaskOutput]: ... + def __call__(self, fn: StreamRunFn[AgentInput, AgentOutput]) -> FinalStreamRunFn[AgentInput, AgentOutput]: ... @overload def __call__( self, - fn: StreamRunFnOutputOnly[TaskInput, TaskOutput], - ) -> FinalStreamRunFnOutputOnly[TaskInput, TaskOutput]: ... + fn: StreamRunFnOutputOnly[AgentInput, AgentOutput], + ) -> FinalStreamRunFnOutputOnly[AgentInput, AgentOutput]: ... - def __call__(self, fn: RunTemplate[TaskInput, TaskOutput]) -> FinalRunTemplate[TaskInput, TaskOutput]: ... + def __call__(self, fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]: ... diff --git a/workflowai/core/client/_utils.py b/workflowai/core/client/_utils.py index 44dc6d2..3d3d181 100644 --- a/workflowai/core/client/_utils.py +++ b/workflowai/core/client/_utils.py @@ -9,7 +9,7 @@ from workflowai.core.client._types import OutputValidator from workflowai.core.domain.errors import BaseError, WorkflowAIError -from workflowai.core.domain.task import TaskOutput +from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.version_reference import VersionReference from workflowai.core.logger import logger @@ -85,7 +85,7 @@ async def _wait_for_exception(e: WorkflowAIError): return _should_retry, _wait_for_exception -def tolerant_validator(m: type[TaskOutput]) -> OutputValidator[TaskOutput]: +def tolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]: return lambda payload: m.model_construct(None, **payload) diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 4386526..9793163 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -9,17 +9,17 @@ from workflowai.core.client._utils import build_retryable_wait, global_default_version_reference, tolerant_validator from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.run import Run -from workflowai.core.domain.task import TaskInput, TaskOutput +from workflowai.core.domain.task import AgentInput, AgentOutput from workflowai.core.domain.version_properties import VersionProperties from workflowai.core.domain.version_reference import VersionReference -class Agent(Generic[TaskInput, TaskOutput]): +class Agent(Generic[AgentInput, AgentOutput]): def __init__( self, agent_id: str, - input_cls: type[TaskInput], - output_cls: type[TaskOutput], + input_cls: type[AgentInput], + output_cls: type[AgentOutput], api: Union[APIClient, Callable[[], APIClient]], schema_id: Optional[int] = None, version: Optional[VersionReference] = None, @@ -55,7 +55,7 @@ def _sanitize_version(self, version: Optional[VersionReference]) -> Union[str, i dumped["model"] = workflowai.DEFAULT_MODEL return dumped - async def _prepare_run(self, task_input: TaskInput, stream: bool, **kwargs: Unpack[RunParams[TaskOutput]]): + async def _prepare_run(self, task_input: AgentInput, stream: bool, **kwargs: Unpack[RunParams[AgentOutput]]): schema_id = self.schema_id if not schema_id: schema_id = await self.register() @@ -94,13 +94,13 @@ async def register(self): async def run( self, - task_input: TaskInput, - **kwargs: Unpack[RunParams[TaskOutput]], - ) -> Run[TaskOutput]: + task_input: AgentInput, + **kwargs: Unpack[RunParams[AgentOutput]], + ) -> Run[AgentOutput]: """Run the agent Args: - task_input (TaskInput): the input to the task + task_input (AgentInput): the input to the task version (Optional[TaskVersionReference], optional): the version of the task to run. If not provided, the version defined in the task is used. Defaults to None. use_cache (CacheUsage, optional): how to use the cache. Defaults to "auto". @@ -118,7 +118,7 @@ async def run( max_retry_count (int, optional): The maximum number of retry attempts. Defaults to 1. Returns: - Union[TaskRun[TaskInput, TaskOutput], AsyncIterator[TaskOutput]]: the task run object + Union[TaskRun[AgentInput, AgentOutput], AsyncIterator[AgentOutput]]: the task run object or an async iterator of output objects """ prepared_run = await self._prepare_run(task_input, stream=False, **kwargs) @@ -137,13 +137,13 @@ async def run( async def stream( self, - task_input: TaskInput, - **kwargs: Unpack[RunParams[TaskOutput]], + task_input: AgentInput, + **kwargs: Unpack[RunParams[AgentOutput]], ): """Stream the output of the agent Args: - task_input (TaskInput): the input to the task + task_input (AgentInput): the input to the task version (Optional[TaskVersionReference], optional): the version of the task to run. If not provided, the version defined in the task is used. Defaults to None. use_cache (CacheUsage, optional): how to use the cache. Defaults to "auto". @@ -161,7 +161,7 @@ async def stream( max_retry_count (int, optional): The maximum number of retry attempts. Defaults to 1. Returns: - Union[TaskRun[TaskInput, TaskOutput], AsyncIterator[TaskOutput]]: the task run object + Union[TaskRun[AgentInput, AgentOutput], AsyncIterator[AgentOutput]]: the task run object or an async iterator of output objects """ prepared_run = await self._prepare_run(task_input, stream=True, **kwargs) diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 228ba3a..ec4d881 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -58,8 +58,8 @@ async def test_success(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, task_run = await agent.run(task_input=HelloTaskInput(name="Alice")) assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc" - assert task_run.task_id == "123" - assert task_run.task_schema_id == 1 + assert task_run.agent_id == "123" + assert task_run.schema_id == 1 reqs = httpx_mock.get_requests() assert len(reqs) == 1 @@ -85,7 +85,7 @@ async def test_stream(self, httpx_mock: HTTPXMock, agent: Agent[HelloTaskInput, chunks = [chunk async for chunk in agent.stream(task_input=HelloTaskInput(name="Alice"))] - outputs = [chunk.task_output for chunk in chunks] + outputs = [chunk.output for chunk in chunks] assert outputs == [ HelloTaskOutput(message=""), HelloTaskOutput(message="hel"), @@ -119,14 +119,14 @@ async def test_stream_not_optional( chunks = [chunk async for chunk in agent_not_optional.stream(task_input=HelloTaskInput(name="Alice"))] - messages = [chunk.task_output.message for chunk in chunks] + messages = [chunk.output.message for chunk in chunks] assert messages == ["", "hel", "hello", "hello"] for chunk in chunks[:-1]: with pytest.raises(AttributeError): # Since the field is not optional, it will raise an attribute error - assert chunk.task_output.another_field - assert chunks[-1].task_output.another_field == "test" + assert chunk.output.another_field + assert chunks[-1].output.another_field == "test" last_message = chunks[-1] assert isinstance(last_message, Run) diff --git a/workflowai/core/client/client_test.py b/workflowai/core/client/client_test.py index f268340..9f81c16 100644 --- a/workflowai/core/client/client_test.py +++ b/workflowai/core/client/client_test.py @@ -39,9 +39,9 @@ async def test_run_output_only(self, workflowai: WorkflowAI, mock_run_fn: Mock): async def fn(task_input: HelloTaskInput) -> HelloTaskOutput: ... mock_run_fn.return_value = Run( - task_output=HelloTaskOutput(message="hello"), - task_id="123", - task_schema_id=1, + output=HelloTaskOutput(message="hello"), + agent_id="123", + schema_id=1, ) output = await fn(HelloTaskInput(name="Alice")) @@ -54,16 +54,16 @@ async def fn(task_input: HelloTaskInput) -> Run[HelloTaskOutput]: ... mock_run_fn.return_value = Run( id="1", - task_output=HelloTaskOutput(message="hello"), - task_id="123", - task_schema_id=1, + output=HelloTaskOutput(message="hello"), + agent_id="123", + schema_id=1, ) - output = await fn(HelloTaskInput(name="Alice")) + run = await fn(HelloTaskInput(name="Alice")) - assert output.id == "1" - assert output.task_output == HelloTaskOutput(message="hello") - assert isinstance(output, Run) + assert run.id == "1" + assert run.output == HelloTaskOutput(message="hello") + assert isinstance(run, Run) async def test_stream(self, workflowai: WorkflowAI, httpx_mock: HTTPXMock): # We avoid mocking the run fn directly here, python does weird things with @@ -86,7 +86,7 @@ def fn(task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... chunks = [chunk async for chunk in fn(HelloTaskInput(name="Alice"))] def _run(output: HelloTaskOutput, **kwargs: Any) -> Run[HelloTaskOutput]: - return Run(id="1", task_id="123", task_schema_id=1, task_output=output, **kwargs) + return Run(id="1", agent_id="123", schema_id=1, output=output, **kwargs) assert chunks == [ _run(HelloTaskOutput(message="")), diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 9314f9a..b35845f 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -3,32 +3,32 @@ from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType] -from workflowai.core.domain.task import TaskOutput +from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.task_version import TaskVersion -class Run(BaseModel, Generic[TaskOutput]): +class Run(BaseModel, Generic[AgentOutput]): """ - A task run is an instance of a task with a specific input and output. + A run is an instance of a agent with a specific input and output. - This class represent a task run that already has been recorded and possibly + This class represent a run that already has been recorded and possibly been evaluated """ id: str = Field( default_factory=lambda: str(uuid.uuid4()), - description="The unique identifier of the task run. This is a UUIDv7.", + description="The unique identifier of the run. This is a UUIDv7.", ) - task_id: str - task_schema_id: int - task_output: TaskOutput + agent_id: str + schema_id: int + output: AgentOutput duration_seconds: Optional[float] = None cost_usd: Optional[float] = None version: Optional[TaskVersion] = Field( default=None, - description="The version of the task that was run. Only provided if the version differs from the version" + description="The version of the agent that was run. Only provided if the version differs from the version" " specified in the request, for example in case of a model fallback", ) diff --git a/workflowai/core/domain/task.py b/workflowai/core/domain/task.py index 285cba9..ec792bd 100644 --- a/workflowai/core/domain/task.py +++ b/workflowai/core/domain/task.py @@ -2,5 +2,5 @@ from pydantic import BaseModel -TaskInput = TypeVar("TaskInput", bound=BaseModel) -TaskOutput = TypeVar("TaskOutput", bound=BaseModel) +AgentInput = TypeVar("AgentInput", bound=BaseModel) +AgentOutput = TypeVar("AgentOutput", bound=BaseModel) From d07d3698733ff96be124a244b4dd442f839c6c78 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Thu, 23 Jan 2025 15:21:12 -0500 Subject: [PATCH 2/4] chore: task version -> version --- workflowai/__init__.py | 2 +- workflowai/core/client/_models.py | 4 ++-- workflowai/core/domain/run.py | 4 ++-- workflowai/core/domain/{task_version.py => version.py} | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) rename workflowai/core/domain/{task_version.py => version.py} (90%) diff --git a/workflowai/__init__.py b/workflowai/__init__.py index adf66ec..ec9eba4 100644 --- a/workflowai/__init__.py +++ b/workflowai/__init__.py @@ -9,7 +9,7 @@ from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError from workflowai.core.domain.model import Model as Model from workflowai.core.domain.run import Run as Run -from workflowai.core.domain.task_version import TaskVersion as TaskVersion +from workflowai.core.domain.version import Version as Version from workflowai.core.domain.version_reference import ( VersionReference as VersionReference, ) diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index fac6ca0..37f44e9 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -7,7 +7,7 @@ from workflowai.core.domain.cache_usage import CacheUsage from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentOutput -from workflowai.core.domain.task_version import TaskVersion +from workflowai.core.domain.version import Version as DVersion from workflowai.core.domain.version_properties import VersionProperties as DVersionProperties @@ -54,7 +54,7 @@ def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidato schema_id=task_schema_id, output=validator(self.task_output), version=self.version - and TaskVersion( + and DVersion( properties=DVersionProperties.model_construct( None, **self.version.properties, diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index b35845f..08ce2fb 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType] from workflowai.core.domain.task import AgentOutput -from workflowai.core.domain.task_version import TaskVersion +from workflowai.core.domain.version import Version class Run(BaseModel, Generic[AgentOutput]): @@ -26,7 +26,7 @@ class Run(BaseModel, Generic[AgentOutput]): duration_seconds: Optional[float] = None cost_usd: Optional[float] = None - version: Optional[TaskVersion] = Field( + version: Optional[Version] = Field( default=None, description="The version of the agent that was run. Only provided if the version differs from the version" " specified in the request, for example in case of a model fallback", diff --git a/workflowai/core/domain/task_version.py b/workflowai/core/domain/version.py similarity index 90% rename from workflowai/core/domain/task_version.py rename to workflowai/core/domain/version.py index a930923..5002d1a 100644 --- a/workflowai/core/domain/task_version.py +++ b/workflowai/core/domain/version.py @@ -3,7 +3,7 @@ from workflowai.core.domain.version_properties import VersionProperties -class TaskVersion(BaseModel): +class Version(BaseModel): properties: VersionProperties = Field( default_factory=VersionProperties, description="The properties used for executing the run.", From c4df8ea7e384a08da6d7bdd9e3ddc89df45b8701 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Thu, 23 Jan 2025 16:42:28 -0500 Subject: [PATCH 3/4] chore: rename more task -> agent --- pyproject.toml | 2 +- workflowai/__init__.py | 4 +-- workflowai/core/client/_fn_utils.py | 20 +++++++------- workflowai/core/client/_types.py | 42 ++++++++++++++--------------- 4 files changed, 34 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ffd8e82..1cce866 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev1" +version = "0.6.0.dev2" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/workflowai/__init__.py b/workflowai/__init__.py index ec9eba4..eac4b5a 100644 --- a/workflowai/__init__.py +++ b/workflowai/__init__.py @@ -3,7 +3,7 @@ from typing_extensions import deprecated -from workflowai.core.client._types import TaskDecorator +from workflowai.core.client._types import AgentDecorator from workflowai.core.client.client import WorkflowAI as WorkflowAI from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError @@ -67,7 +67,7 @@ def agent( schema_id: Optional[int] = None, version: Optional[VersionReference] = None, model: Optional[Model] = None, -) -> TaskDecorator: +) -> AgentDecorator: from workflowai.core.client._fn_utils import agent_wrapper return agent_wrapper( diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index b340ab8..5421999 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -19,10 +19,10 @@ from workflowai.core.client._api import APIClient from workflowai.core.client._types import ( + AgentDecorator, FinalRunTemplate, RunParams, RunTemplate, - TaskDecorator, ) from workflowai.core.client.agent import Agent from workflowai.core.domain.model import Model @@ -91,23 +91,23 @@ def extract_fn_spec(fn: RunTemplate[AgentInput, AgentOutput]) -> RunFunctionSpec class _RunnableAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): - async def __call__(self, task_input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): - return await self.run(task_input, **kwargs) + async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + return await self.run(input, **kwargs) class _RunnableOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): - async def __call__(self, task_input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): - return (await self.run(task_input, **kwargs)).output + async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + return (await self.run(input, **kwargs)).output class _RunnableStreamAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): - def __call__(self, task_input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): - return self.stream(task_input, **kwargs) + def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + return self.stream(input, **kwargs) class _RunnableStreamOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]): - async def __call__(self, task_input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): - async for chunk in self.stream(task_input, **kwargs): + async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002 + async for chunk in self.stream(input, **kwargs): yield chunk.output @@ -156,7 +156,7 @@ def agent_wrapper( agent_id: Optional[str] = None, version: Optional[VersionReference] = None, model: Optional[Model] = None, -) -> TaskDecorator: +) -> AgentDecorator: def wrap(fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]: tid = agent_id or agent_id_from_fn_name(fn) return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn)) # pyright: ignore [reportReturnType] diff --git a/workflowai/core/client/_types.py b/workflowai/core/client/_types.py index fddef7e..9fad07c 100644 --- a/workflowai/core/client/_types.py +++ b/workflowai/core/client/_types.py @@ -18,8 +18,8 @@ from workflowai.core.domain.task import AgentInput, AgentOutput from workflowai.core.domain.version_reference import VersionReference -TaskInputContra = TypeVar("TaskInputContra", bound=BaseModel, contravariant=True) -TaskOutputCov = TypeVar("TaskOutputCov", bound=BaseModel, covariant=True) +AgentInputContra = TypeVar("AgentInputContra", bound=BaseModel, contravariant=True) +AgentOutputCov = TypeVar("AgentOutputCov", bound=BaseModel, covariant=True) OutputValidator = Callable[[dict[str, Any]], AgentOutput] @@ -34,26 +34,26 @@ class RunParams(TypedDict, Generic[AgentOutput]): validator: NotRequired[OutputValidator[AgentOutput]] -class RunFn(Protocol, Generic[TaskInputContra, AgentOutput]): - async def __call__(self, task_input: TaskInputContra) -> Run[AgentOutput]: ... +class RunFn(Protocol, Generic[AgentInputContra, AgentOutput]): + async def __call__(self, task_input: AgentInputContra) -> Run[AgentOutput]: ... -class RunFnOutputOnly(Protocol, Generic[TaskInputContra, TaskOutputCov]): - async def __call__(self, task_input: TaskInputContra) -> TaskOutputCov: ... +class RunFnOutputOnly(Protocol, Generic[AgentInputContra, AgentOutputCov]): + async def __call__(self, task_input: AgentInputContra) -> AgentOutputCov: ... -class StreamRunFn(Protocol, Generic[TaskInputContra, AgentOutput]): +class StreamRunFn(Protocol, Generic[AgentInputContra, AgentOutput]): def __call__( self, - task_input: TaskInputContra, + task_input: AgentInputContra, ) -> AsyncIterator[Run[AgentOutput]]: ... -class StreamRunFnOutputOnly(Protocol, Generic[TaskInputContra, TaskOutputCov]): +class StreamRunFnOutputOnly(Protocol, Generic[AgentInputContra, AgentOutputCov]): def __call__( self, - task_input: TaskInputContra, - ) -> AsyncIterator[TaskOutputCov]: ... + task_input: AgentInputContra, + ) -> AsyncIterator[AgentOutputCov]: ... RunTemplate = Union[ @@ -75,36 +75,36 @@ class _BaseProtocol(Protocol): __code__: Any -class FinalRunFn(_BaseProtocol, Protocol, Generic[TaskInputContra, AgentOutput]): +class FinalRunFn(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): async def __call__( self, - task_input: TaskInputContra, + input: AgentInputContra, # noqa: A002 **kwargs: Unpack[RunParams[AgentOutput]], ) -> Run[AgentOutput]: ... -class FinalRunFnOutputOnly(_BaseProtocol, Protocol, Generic[TaskInputContra, AgentOutput]): +class FinalRunFnOutputOnly(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): async def __call__( self, - task_input: TaskInputContra, + input: AgentInputContra, # noqa: A002 **kwargs: Unpack[RunParams[AgentOutput]], ) -> AgentOutput: ... -class FinalStreamRunFn(_BaseProtocol, Protocol, Generic[TaskInputContra, AgentOutput]): +class FinalStreamRunFn(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): def __call__( self, - task_input: TaskInputContra, + input: AgentInputContra, # noqa: A002 **kwargs: Unpack[RunParams[AgentOutput]], ) -> AsyncIterator[Run[AgentOutput]]: ... -class FinalStreamRunFnOutputOnly(_BaseProtocol, Protocol, Generic[TaskInputContra, TaskOutputCov]): +class FinalStreamRunFnOutputOnly(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutputCov]): def __call__( self, - task_input: TaskInputContra, + input: AgentInputContra, # noqa: A002 **kwargs: Unpack[RunParams[AgentOutput]], - ) -> AsyncIterator[TaskOutputCov]: ... + ) -> AsyncIterator[AgentOutputCov]: ... FinalRunTemplate = Union[ @@ -115,7 +115,7 @@ def __call__( ] -class TaskDecorator(Protocol): +class AgentDecorator(Protocol): @overload def __call__(self, fn: RunFn[AgentInput, AgentOutput]) -> FinalRunFn[AgentInput, AgentOutput]: ... From 65059f33203b2b9742e4165fc4f54d89fc6725ca Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Thu, 23 Jan 2025 16:53:57 -0500 Subject: [PATCH 4/4] feat: allow any naming for first positional param --- tests/e2e/no_schema_test.py | 2 +- tests/e2e/run_test.py | 2 +- workflowai/core/client/_fn_utils.py | 20 ++++++++++++++---- workflowai/core/client/_fn_utils_test.py | 8 ++++---- workflowai/core/client/_types.py | 26 +++++++++++------------- 5 files changed, 34 insertions(+), 24 deletions(-) diff --git a/tests/e2e/no_schema_test.py b/tests/e2e/no_schema_test.py index 6449c86..dc09a62 100644 --- a/tests/e2e/no_schema_test.py +++ b/tests/e2e/no_schema_test.py @@ -14,7 +14,7 @@ class SummarizeTaskOutput(BaseModel): @workflowai.agent(id="summarize", model="gemini-1.5-flash-latest") -async def summarize(task_input: SummarizeTaskInput) -> SummarizeTaskOutput: ... +async def summarize(_: SummarizeTaskInput) -> SummarizeTaskOutput: ... async def test_summarize(): diff --git a/tests/e2e/run_test.py b/tests/e2e/run_test.py index fbb2cdd..4dc8a5a 100644 --- a/tests/e2e/run_test.py +++ b/tests/e2e/run_test.py @@ -27,7 +27,7 @@ class ExtractProductReviewSentimentTaskOutput(BaseModel): @workflowai.agent(id="extract-product-review-sentiment", schema_id=1) def extract_product_review_sentiment( - task_input: ExtractProductReviewSentimentTaskInput, + _: ExtractProductReviewSentimentTaskInput, ) -> AsyncIterator[ExtractProductReviewSentimentTaskOutput]: ... diff --git a/workflowai/core/client/_fn_utils.py b/workflowai/core/client/_fn_utils.py index 5421999..57aee97 100644 --- a/workflowai/core/client/_fn_utils.py +++ b/workflowai/core/client/_fn_utils.py @@ -1,4 +1,5 @@ import functools +import inspect from collections.abc import Callable from typing import ( Any, @@ -66,17 +67,28 @@ def is_async_iterator(t: type[Any]) -> bool: return issubclass(ori, AsyncIterator) +def _first_arg_name(fn: Callable[..., Any]) -> Optional[str]: + sig = inspect.signature(fn) + for param in sig.parameters.values(): + if param.kind == param.POSITIONAL_OR_KEYWORD: + return param.name + return None + + def extract_fn_spec(fn: RunTemplate[AgentInput, AgentOutput]) -> RunFunctionSpec: + first_arg_name = _first_arg_name(fn) + if not first_arg_name: + raise ValueError("Function must have a first positional argument") hints = get_type_hints(fn) if "return" not in hints: raise ValueError("Function must have a return type hint") - if "task_input" not in hints: - raise ValueError("Function must have a task_input parameter") + if first_arg_name not in hints: + raise ValueError("Function must have a first positional parameter") return_type_hint = hints["return"] - input_cls = hints["task_input"] + input_cls = hints[first_arg_name] if not issubclass(input_cls, BaseModel): - raise ValueError("task_input must be a subclass of BaseModel") + raise ValueError("First positional parameter must be a subclass of BaseModel") output_cls = None diff --git a/workflowai/core/client/_fn_utils_test.py b/workflowai/core/client/_fn_utils_test.py index 0d955fd..c5cd37f 100644 --- a/workflowai/core/client/_fn_utils_test.py +++ b/workflowai/core/client/_fn_utils_test.py @@ -21,16 +21,16 @@ from workflowai.core.domain.run import Run -async def say_hello(task_input: HelloTaskInput) -> HelloTaskOutput: ... +async def say_hello(_: HelloTaskInput) -> HelloTaskOutput: ... -async def say_hello_run(task_input: HelloTaskInput) -> Run[HelloTaskOutput]: ... +async def say_hello_run(bla: HelloTaskInput) -> Run[HelloTaskOutput]: ... -def stream_hello(task_input: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ... +def stream_hello(_: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ... -def stream_hello_run(task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... +def stream_hello_run(_: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ... class TestGetGenericArgs: diff --git a/workflowai/core/client/_types.py b/workflowai/core/client/_types.py index 9fad07c..7dd3bdc 100644 --- a/workflowai/core/client/_types.py +++ b/workflowai/core/client/_types.py @@ -35,25 +35,19 @@ class RunParams(TypedDict, Generic[AgentOutput]): class RunFn(Protocol, Generic[AgentInputContra, AgentOutput]): - async def __call__(self, task_input: AgentInputContra) -> Run[AgentOutput]: ... + async def __call__(self, _: AgentInputContra, /) -> Run[AgentOutput]: ... class RunFnOutputOnly(Protocol, Generic[AgentInputContra, AgentOutputCov]): - async def __call__(self, task_input: AgentInputContra) -> AgentOutputCov: ... + async def __call__(self, _: AgentInputContra, /) -> AgentOutputCov: ... class StreamRunFn(Protocol, Generic[AgentInputContra, AgentOutput]): - def __call__( - self, - task_input: AgentInputContra, - ) -> AsyncIterator[Run[AgentOutput]]: ... + def __call__(self, _: AgentInputContra, /) -> AsyncIterator[Run[AgentOutput]]: ... class StreamRunFnOutputOnly(Protocol, Generic[AgentInputContra, AgentOutputCov]): - def __call__( - self, - task_input: AgentInputContra, - ) -> AsyncIterator[AgentOutputCov]: ... + def __call__(self, _: AgentInputContra, /) -> AsyncIterator[AgentOutputCov]: ... RunTemplate = Union[ @@ -78,7 +72,8 @@ class _BaseProtocol(Protocol): class FinalRunFn(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): async def __call__( self, - input: AgentInputContra, # noqa: A002 + _: AgentInputContra, + /, **kwargs: Unpack[RunParams[AgentOutput]], ) -> Run[AgentOutput]: ... @@ -86,7 +81,8 @@ async def __call__( class FinalRunFnOutputOnly(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): async def __call__( self, - input: AgentInputContra, # noqa: A002 + _: AgentInputContra, + /, **kwargs: Unpack[RunParams[AgentOutput]], ) -> AgentOutput: ... @@ -94,7 +90,8 @@ async def __call__( class FinalStreamRunFn(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutput]): def __call__( self, - input: AgentInputContra, # noqa: A002 + _: AgentInputContra, + /, **kwargs: Unpack[RunParams[AgentOutput]], ) -> AsyncIterator[Run[AgentOutput]]: ... @@ -102,7 +99,8 @@ def __call__( class FinalStreamRunFnOutputOnly(_BaseProtocol, Protocol, Generic[AgentInputContra, AgentOutputCov]): def __call__( self, - input: AgentInputContra, # noqa: A002 + _: AgentInputContra, + /, **kwargs: Unpack[RunParams[AgentOutput]], ) -> AsyncIterator[AgentOutputCov]: ...