Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/city_to_capital_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTa
def main(city: str) -> None:
async def _inner() -> None:
task_input = CityToCapitalTaskInput(city=city)
task_output = await city_to_capital(task_input)
output = await city_to_capital(task_input)

rprint(task_output)
rprint(output)

aiorun(_inner())

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.6.0.dev1"
version = "0.6.0.dev2"
description = ""
authors = ["Guillaume Aquilina <guillaume@workflowai.com>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/no_schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class SummarizeTaskOutput(BaseModel):


@workflowai.agent(id="summarize", model="gemini-1.5-flash-latest")
async def summarize(task_input: SummarizeTaskInput) -> SummarizeTaskOutput: ...
async def summarize(_: SummarizeTaskInput) -> SummarizeTaskOutput: ...


async def test_summarize():
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ExtractProductReviewSentimentTaskOutput(BaseModel):

@workflowai.agent(id="extract-product-review-sentiment", schema_id=1)
def extract_product_review_sentiment(
task_input: ExtractProductReviewSentimentTaskInput,
_: ExtractProductReviewSentimentTaskInput,
) -> AsyncIterator[ExtractProductReviewSentimentTaskOutput]: ...


Expand All @@ -52,7 +52,7 @@ async def test_run_task(
):
task_input = ExtractProductReviewSentimentTaskInput(review_text="This product is amazing!")
run = await extract_product_review_sentiment_agent.run(task_input=task_input, use_cache="never")
assert run.task_output.sentiment == Sentiment.POSITIVE
assert run.output.sentiment == Sentiment.POSITIVE


async def test_stream_task(
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> CityToCapitalTa
_mock_response(httpx_mock)

task_input = CityToCapitalTaskInput(city="Hello")
task_output = await city_to_capital(task_input)
output = await city_to_capital(task_input)

assert task_output.capital == "Tokyo"
assert output.capital == "Tokyo"

_check_request(httpx_mock.get_request())

Expand All @@ -90,7 +90,7 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit
with_run = await city_to_capital(task_input)

assert with_run.id == "123"
assert with_run.task_output.capital == "Tokyo"
assert with_run.output.capital == "Tokyo"

_check_request(httpx_mock.get_request())

Expand All @@ -105,7 +105,7 @@ async def city_to_capital(task_input: CityToCapitalTaskInput) -> Run[CityToCapit
with_run = await city_to_capital(task_input)

assert with_run.id == "123"
assert with_run.task_output.capital == "Tokyo"
assert with_run.output.capital == "Tokyo"

_check_request(httpx_mock.get_request(), version="staging")

Expand Down
6 changes: 3 additions & 3 deletions workflowai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from typing_extensions import deprecated

from workflowai.core.client._types import TaskDecorator
from workflowai.core.client._types import AgentDecorator
from workflowai.core.client.client import WorkflowAI as WorkflowAI
from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage
from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError
from workflowai.core.domain.model import Model as Model
from workflowai.core.domain.run import Run as Run
from workflowai.core.domain.task_version import TaskVersion as TaskVersion
from workflowai.core.domain.version import Version as Version
from workflowai.core.domain.version_reference import (
VersionReference as VersionReference,
)
Expand Down Expand Up @@ -67,7 +67,7 @@ def agent(
schema_id: Optional[int] = None,
version: Optional[VersionReference] = None,
model: Optional[Model] = None,
) -> TaskDecorator:
) -> AgentDecorator:
from workflowai.core.client._fn_utils import agent_wrapper

return agent_wrapper(
Expand Down
66 changes: 39 additions & 27 deletions workflowai/core/client/_fn_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import inspect
from collections.abc import Callable
from typing import (
Any,
Expand All @@ -19,15 +20,15 @@

from workflowai.core.client._api import APIClient
from workflowai.core.client._types import (
AgentDecorator,
FinalRunTemplate,
RunParams,
RunTemplate,
TaskDecorator,
)
from workflowai.core.client.agent import Agent
from workflowai.core.domain.model import Model
from workflowai.core.domain.run import Run
from workflowai.core.domain.task import TaskInput, TaskOutput
from workflowai.core.domain.task import AgentInput, AgentOutput
from workflowai.core.domain.version_properties import VersionProperties
from workflowai.core.domain.version_reference import VersionReference

Expand Down Expand Up @@ -66,17 +67,28 @@ def is_async_iterator(t: type[Any]) -> bool:
return issubclass(ori, AsyncIterator)


def extract_fn_spec(fn: RunTemplate[TaskInput, TaskOutput]) -> RunFunctionSpec:
def _first_arg_name(fn: Callable[..., Any]) -> Optional[str]:
sig = inspect.signature(fn)
for param in sig.parameters.values():
if param.kind == param.POSITIONAL_OR_KEYWORD:
return param.name
return None


def extract_fn_spec(fn: RunTemplate[AgentInput, AgentOutput]) -> RunFunctionSpec:
first_arg_name = _first_arg_name(fn)
if not first_arg_name:
raise ValueError("Function must have a first positional argument")
hints = get_type_hints(fn)
if "return" not in hints:
raise ValueError("Function must have a return type hint")
if "task_input" not in hints:
raise ValueError("Function must have a task_input parameter")
if first_arg_name not in hints:
raise ValueError("Function must have a first positional parameter")

return_type_hint = hints["return"]
input_cls = hints["task_input"]
input_cls = hints[first_arg_name]
if not issubclass(input_cls, BaseModel):
raise ValueError("task_input must be a subclass of BaseModel")
raise ValueError("First positional parameter must be a subclass of BaseModel")

output_cls = None

Expand All @@ -90,25 +102,25 @@ def extract_fn_spec(fn: RunTemplate[TaskInput, TaskOutput]) -> RunFunctionSpec:
return RunFunctionSpec(stream, output_only, input_cls, output_cls)


class _RunnableAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]):
async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]):
return await self.run(task_input, **kwargs)
class _RunnableAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]):
async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002
return await self.run(input, **kwargs)


class _RunnableOutputOnlyAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]):
async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]):
return (await self.run(task_input, **kwargs)).task_output
class _RunnableOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]):
async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002
return (await self.run(input, **kwargs)).output


class _RunnableStreamAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]):
def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]):
return self.stream(task_input, **kwargs)
class _RunnableStreamAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]):
def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002
return self.stream(input, **kwargs)


class _RunnableStreamOutputOnlyAgent(Agent[TaskInput, TaskOutput], Generic[TaskInput, TaskOutput]):
async def __call__(self, task_input: TaskInput, **kwargs: Unpack[RunParams[TaskOutput]]):
async for chunk in self.stream(task_input, **kwargs):
yield chunk.task_output
class _RunnableStreamOutputOnlyAgent(Agent[AgentInput, AgentOutput], Generic[AgentInput, AgentOutput]):
async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutput]]): # noqa: A002
async for chunk in self.stream(input, **kwargs):
yield chunk.output


def wrap_run_template(
Expand All @@ -117,12 +129,12 @@ def wrap_run_template(
schema_id: Optional[int],
version: Optional[VersionReference],
model: Optional[Model],
fn: RunTemplate[TaskInput, TaskOutput],
fn: RunTemplate[AgentInput, AgentOutput],
) -> Union[
_RunnableAgent[TaskInput, TaskOutput],
_RunnableOutputOnlyAgent[TaskInput, TaskOutput],
_RunnableStreamAgent[TaskInput, TaskOutput],
_RunnableStreamOutputOnlyAgent[TaskInput, TaskOutput],
_RunnableAgent[AgentInput, AgentOutput],
_RunnableOutputOnlyAgent[AgentInput, AgentOutput],
_RunnableStreamAgent[AgentInput, AgentOutput],
_RunnableStreamOutputOnlyAgent[AgentInput, AgentOutput],
]:
stream, output_only, input_cls, output_cls = extract_fn_spec(fn)

Expand Down Expand Up @@ -156,8 +168,8 @@ def agent_wrapper(
agent_id: Optional[str] = None,
version: Optional[VersionReference] = None,
model: Optional[Model] = None,
) -> TaskDecorator:
def wrap(fn: RunTemplate[TaskInput, TaskOutput]) -> FinalRunTemplate[TaskInput, TaskOutput]:
) -> AgentDecorator:
def wrap(fn: RunTemplate[AgentInput, AgentOutput]) -> FinalRunTemplate[AgentInput, AgentOutput]:
tid = agent_id or agent_id_from_fn_name(fn)
return functools.wraps(fn)(wrap_run_template(client, tid, schema_id, version, model, fn)) # pyright: ignore [reportReturnType]

Expand Down
12 changes: 6 additions & 6 deletions workflowai/core/client/_fn_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
from workflowai.core.domain.run import Run


async def say_hello(task_input: HelloTaskInput) -> HelloTaskOutput: ...
async def say_hello(_: HelloTaskInput) -> HelloTaskOutput: ...


async def say_hello_run(task_input: HelloTaskInput) -> Run[HelloTaskOutput]: ...
async def say_hello_run(bla: HelloTaskInput) -> Run[HelloTaskOutput]: ...


def stream_hello(task_input: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ...
def stream_hello(_: HelloTaskInput) -> AsyncIterator[HelloTaskOutput]: ...


def stream_hello_run(task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ...
def stream_hello_run(_: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ...


class TestGetGenericArgs:
Expand Down Expand Up @@ -76,7 +76,7 @@ async def test_fn_run(self, mock_api_client: Mock):
run = await wrapped(HelloTaskInput(name="World"))
assert isinstance(run, Run)
assert run.id == "1"
assert run.task_output == HelloTaskOutput(message="Hello, World!")
assert run.output == HelloTaskOutput(message="Hello, World!")

def fn_stream(self, task_input: HelloTaskInput) -> AsyncIterator[Run[HelloTaskOutput]]: ...

Expand All @@ -89,7 +89,7 @@ async def test_fn_stream(self, mock_api_client: Mock):
assert len(chunks) == 1
assert isinstance(chunks[0], Run)
assert chunks[0].id == "1"
assert chunks[0].task_output == HelloTaskOutput(message="Hello, World!")
assert chunks[0].output == HelloTaskOutput(message="Hello, World!")

async def fn_run_output_only(self, task_input: HelloTaskInput) -> HelloTaskOutput: ...

Expand Down
14 changes: 7 additions & 7 deletions workflowai/core/client/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from workflowai.core.client._types import OutputValidator
from workflowai.core.domain.cache_usage import CacheUsage
from workflowai.core.domain.run import Run
from workflowai.core.domain.task import TaskOutput
from workflowai.core.domain.task_version import TaskVersion
from workflowai.core.domain.task import AgentOutput
from workflowai.core.domain.version import Version as DVersion
from workflowai.core.domain.version_properties import VersionProperties as DVersionProperties


Expand Down Expand Up @@ -47,14 +47,14 @@ class RunResponse(BaseModel):
duration_seconds: Optional[float] = None
cost_usd: Optional[float] = None

def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidator[TaskOutput]) -> Run[TaskOutput]:
def to_domain(self, task_id: str, task_schema_id: int, validator: OutputValidator[AgentOutput]) -> Run[AgentOutput]:
return Run(
id=self.id,
task_id=task_id,
task_schema_id=task_schema_id,
task_output=validator(self.task_output),
agent_id=task_id,
schema_id=task_schema_id,
output=validator(self.task_output),
version=self.version
and TaskVersion(
and DVersion(
properties=DVersionProperties.model_construct(
None,
**self.version.properties,
Expand Down
12 changes: 6 additions & 6 deletions workflowai/core/client/_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ def test_no_version_not_optional(self):

parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
assert isinstance(parsed, Run)
assert parsed.task_output.a == 1
assert parsed.output.a == 1
# b is not defined
with pytest.raises(AttributeError):
assert parsed.task_output.b
assert parsed.output.b

def test_no_version_optional(self):
chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}')
assert chunk

parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutputOpt))
assert isinstance(parsed, Run)
assert parsed.task_output.a == 1
assert parsed.task_output.b is None
assert parsed.output.a == 1
assert parsed.output.b is None

def test_with_version(self):
chunk = RunResponse.model_validate_json(
Expand All @@ -64,8 +64,8 @@ def test_with_version(self):

parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
assert isinstance(parsed, Run)
assert parsed.task_output.a == 1
assert parsed.task_output.b == "test"
assert parsed.output.a == 1
assert parsed.output.b == "test"

assert parsed.cost_usd == 0.1
assert parsed.duration_seconds == 1
Expand Down
Loading
Loading