From 0e266f82d95675e8a0f00e544dc9bbdf2b1c92b7 Mon Sep 17 00:00:00 2001 From: Pierre Date: Fri, 21 Feb 2025 12:28:24 -0700 Subject: [PATCH 1/3] feat: Add fetch_completions method to Run class Added ability to fetch completions for a run via the /v1/_/agents/{agent_id}/runs/{run_id}/completions endpoint. Changes:\n- Added CompletionUsage, Message, Completion, and CompletionsResponse models\n- Added fetch_completions() method to Run class\n- Added comprehensive test suite with success and error cases\n- Added detailed test documentation explaining test rationale\n\nThe fetch_completions method allows retrieving the full conversation history and token usage information for a completed run. --- workflowai/core/domain/run.py | 56 +++++++++++++++++++ workflowai/core/domain/run_test.py | 86 +++++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 1 deletion(-) diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index f7f7950..348cb44 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -7,12 +7,45 @@ from workflowai import env from workflowai.core import _common_types from workflowai.core.client import _types +from workflowai.core.client._api import APIClient from workflowai.core.domain.errors import BaseError from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.tool_call import ToolCall, ToolCallRequest, ToolCallResult from workflowai.core.domain.version import Version +class CompletionUsage(BaseModel): + """Usage information for a completion.""" + completion_token_count: int + completion_cost_usd: float + reasoning_token_count: int + prompt_token_count: int + prompt_token_count_cached: int + prompt_cost_usd: float + prompt_audio_token_count: int + prompt_audio_duration_seconds: float + prompt_image_count: int + model_context_window_size: int + + +class Message(BaseModel): + """A message in a completion.""" + role: str + content: str + + +class Completion(BaseModel): + """A completion 
from the model.""" + messages: list[Message] + response: str + usage: CompletionUsage + + +class CompletionsResponse(BaseModel): + """Response from the completions API endpoint.""" + completions: list[Completion] + + class Run(BaseModel, Generic[AgentOutput]): """ A run is an instance of a agent with a specific input and output. @@ -125,8 +158,31 @@ def __str__(self) -> str: def run_url(self): return f"{env.WORKFLOWAI_APP_URL}/agents/{self.agent_id}/runs/{self.id}" + async def fetch_completions(self) -> CompletionsResponse: + """Fetch the completions for this run. + + Returns: + CompletionsResponse: The completions response containing a list of completions + with their messages, responses and usage information. + + Raises: + ValueError: If the agent is not set or if the run id is not set. + """ + if not self._agent: + raise ValueError("Agent is not set") + if not self.id: + raise ValueError("Run id is not set") + + # The "_" refers to the currently authenticated tenant's namespace + return await self._agent.api.get( + f"/v1/_/agents/{self.agent_id}/runs/{self.id}/completions", + returns=CompletionsResponse, + ) + class _AgentBase(Protocol, Generic[AgentOutput]): + api: APIClient + async def reply( self, run_id: str, diff --git a/workflowai/core/domain/run_test.py b/workflowai/core/domain/run_test.py index 230576e..d2fc46d 100644 --- a/workflowai/core/domain/run_test.py +++ b/workflowai/core/domain/run_test.py @@ -3,7 +3,8 @@ import pytest from pydantic import BaseModel -from workflowai.core.domain.run import Run +from workflowai.core.client._api import APIClient +from workflowai.core.domain.run import Completion, CompletionsResponse, CompletionUsage, Message, Run from workflowai.core.domain.version import Version from workflowai.core.domain.version_properties import VersionProperties @@ -120,3 +121,86 @@ class TestRunURL: @patch("workflowai.env.WORKFLOWAI_APP_URL", "https://workflowai.hello") def test_run_url(self, run1: Run[_TestOutput]): assert run1.run_url == 
"https://workflowai.hello/agents/agent-1/runs/test-id" + + +class TestFetchCompletions: + """Tests for the fetch_completions method of the Run class.""" + + # Test the successful case of fetching completions: + # 1. Verifies that the API client is called with the correct URL and parameters + # 2. Verifies that the response is properly parsed into CompletionsResponse + # 3. Checks that all fields (messages, response, usage) are correctly populated + # 4. Ensures the completion contains the expected conversation history (system, user, assistant) + async def test_fetch_completions_success(self, run1: Run[_TestOutput]): + # Create a mock API client + mock_api = Mock(spec=APIClient) + mock_api.get.return_value = CompletionsResponse( + completions=[ + Completion( + messages=[ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Hello"), + Message(role="assistant", content="Hi there!"), + ], + response="Hi there!", + usage=CompletionUsage( + completion_token_count=3, + completion_cost_usd=0.001, + reasoning_token_count=10, + prompt_token_count=20, + prompt_token_count_cached=0, + prompt_cost_usd=0.002, + prompt_audio_token_count=0, + prompt_audio_duration_seconds=0, + prompt_image_count=0, + model_context_window_size=32000, + ), + ), + ], + ) + + # Create a mock agent with the mock API client + mock_agent = Mock() + mock_agent.api = mock_api + run1._agent = mock_agent # pyright: ignore [reportPrivateUsage] + + # Call fetch_completions + completions = await run1.fetch_completions() + + # Verify the API was called correctly + mock_api.get.assert_called_once_with( + "/v1/_/agents/agent-1/runs/test-id/completions", + returns=CompletionsResponse, + ) + + # Verify the response + assert len(completions.completions) == 1 + completion = completions.completions[0] + assert len(completion.messages) == 3 + assert completion.messages[0].role == "system" + assert completion.messages[0].content == "You are a helpful assistant" + assert 
completion.response == "Hi there!" + assert completion.usage.completion_token_count == 3 + assert completion.usage.completion_cost_usd == 0.001 + + # Test that fetch_completions fails appropriately when the agent is not set: + # 1. This is a common error case that occurs when a Run object is created without an agent + # 2. The method should fail fast with a clear error message before attempting any API calls + # 3. This protects users from confusing errors that would occur if we tried to use the API client + async def test_fetch_completions_no_agent(self, run1: Run[_TestOutput]): + run1._agent = None # pyright: ignore [reportPrivateUsage] + with pytest.raises(ValueError, match="Agent is not set"): + await run1.fetch_completions() + + # Test that fetch_completions fails appropriately when the run ID is not set: + # 1. The run ID is required to construct the API endpoint URL + # 2. Without it, we can't make a valid API request + # 3. This validates that we fail fast with a clear error message + # 4. This should never happen in practice (as Run objects always have an ID), + # but we test it for completeness and to ensure robust error handling + async def test_fetch_completions_no_id(self, run1: Run[_TestOutput]): + mock_agent = Mock() + run1._agent = mock_agent # pyright: ignore [reportPrivateUsage] + run1.id = "" # Empty ID + with pytest.raises(ValueError, match="Run id is not set"): + await run1.fetch_completions() From c2cc02bf6ef4b6ce7aab135ad7c5189457256815 Mon Sep 17 00:00:00 2001 From: Pierre Date: Fri, 21 Feb 2025 12:30:33 -0700 Subject: [PATCH 2/3] fix: Make api a property in _AgentBase protocol Fixed type error in Run._agent assignment by making the api field a property in the _AgentBase protocol to match the Agent class implementation. This ensures proper type compatibility between Agent and _AgentBase. 
--- workflowai/core/domain/run.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 348cb44..ee3c037 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -181,7 +181,8 @@ async def fetch_completions(self) -> CompletionsResponse: class _AgentBase(Protocol, Generic[AgentOutput]): - api: APIClient + @property + def api(self) -> APIClient: ... async def reply( self, From 37c97aa6d5090046b38dc5b678d8a1ef6f076448 Mon Sep 17 00:00:00 2001 From: Guillaume Aquilina Date: Fri, 21 Feb 2025 15:19:16 -0700 Subject: [PATCH 3/3] feat: move fetch completion to agent --- tests/fixtures/completions.json | 24 ++++++++ tests/fixtures/task_example.json | 16 ----- workflowai/core/client/_models.py | 11 ++++ workflowai/core/client/agent.py | 17 ++++++ workflowai/core/client/agent_test.py | 29 ++++++++++ workflowai/core/domain/completion.py | 39 +++++++++++++ workflowai/core/domain/run.py | 47 ++------------- workflowai/core/domain/run_test.py | 87 +++++++++++++--------------- 8 files changed, 166 insertions(+), 104 deletions(-) create mode 100644 tests/fixtures/completions.json delete mode 100644 tests/fixtures/task_example.json create mode 100644 workflowai/core/domain/completion.py diff --git a/tests/fixtures/completions.json b/tests/fixtures/completions.json new file mode 100644 index 0000000..7bba405 --- /dev/null +++ b/tests/fixtures/completions.json @@ -0,0 +1,24 @@ +{ + "completions": [ + { + "messages": [ + { + "role": "system", + "content": "I am instructions" + }, + { + "role": "user", + "content": "I am user message" + } + ], + "response": "This is a test response", + "usage": { + "completion_token_count": 222, + "completion_cost_usd": 0.00013319999999999999, + "prompt_token_count": 1230, + "prompt_cost_usd": 0.00018449999999999999, + "model_context_window_size": 1048576 + } + } + ] +} diff --git a/tests/fixtures/task_example.json 
b/tests/fixtures/task_example.json deleted file mode 100644 index 94718f2..0000000 --- a/tests/fixtures/task_example.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "id": "8f635b73-f403-47ee-bff9-18320616c6cc", - "task_id": "citytocapital", - "task_schema_id": 1, - "task_input": { - "name": "Houston" - }, - "task_input_hash": "403c3739ab1c20643336dde3ad2950bb", - "task_input_preview": "city: \"Houston\"", - "task_output": { - "message": "Austin" - }, - "task_output_hash": "41a00820c2d20f738de5cc6bcb02b550", - "task_output_preview": "capital: \"Austin\"", - "created_at": "2024-05-31T01:38:48.688000Z" -} diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 7dd55e7..3f060c1 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -5,6 +5,7 @@ from workflowai.core._common_types import OutputValidator from workflowai.core.domain.cache_usage import CacheUsage +from workflowai.core.domain.completion import Completion from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.tool_call import ToolCall as DToolCall @@ -160,6 +161,7 @@ class CreateAgentResponse(BaseModel): class ModelMetadata(BaseModel): """Metadata for a model.""" + provider_name: str = Field(description="Name of the model provider") price_per_input_token_usd: Optional[float] = Field(None, description="Cost per input token in USD") price_per_output_token_usd: Optional[float] = Field(None, description="Cost per output token in USD") @@ -170,6 +172,7 @@ class ModelMetadata(BaseModel): class ModelInfo(BaseModel): """Information about a model.""" + id: str = Field(description="Unique identifier for the model") name: str = Field(description="Display name of the model") icon_url: Optional[str] = Field(None, description="URL for the model's icon") @@ -187,11 +190,19 @@ class ModelInfo(BaseModel): T = TypeVar("T") + class Page(BaseModel, Generic[T]): """A generic paginated response.""" + 
items: list[T] = Field(description="List of items in this page") count: Optional[int] = Field(None, description="Total number of items available") class ListModelsResponse(Page[ModelInfo]): """Response from the list models API endpoint.""" + + +class CompletionsResponse(BaseModel): + """Response from the completions API endpoint.""" + + completions: list[Completion] diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 81004f0..5198d6e 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -9,6 +9,7 @@ from workflowai.core._common_types import BaseRunParams, OutputValidator, VersionRunParams from workflowai.core.client._api import APIClient from workflowai.core.client._models import ( + CompletionsResponse, CreateAgentRequest, CreateAgentResponse, ListModelsResponse, @@ -24,6 +25,7 @@ intolerant_validator, tolerant_validator, ) +from workflowai.core.domain.completion import Completion from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput @@ -493,3 +495,18 @@ async def list_models(self) -> list[ModelInfo]: returns=ListModelsResponse, ) return response.items + + async def fetch_completions(self, run_id: str) -> list[Completion]: + """Fetch the completions for a run. + + Args: + run_id (str): The id of the run to fetch completions for. + + Returns: + list[Completion]: The completions for the run. 
+ """ + raw = await self.api.get( + f"/v1/_/agents/{self.agent_id}/runs/{run_id}/completions", + returns=CompletionsResponse, + ) + return raw.completions diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 7168b43..f62bf59 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -19,6 +19,7 @@ from workflowai.core.client.client import ( WorkflowAI, ) +from workflowai.core.domain.completion import Completion, CompletionUsage, Message from workflowai.core.domain.errors import WorkflowAIError from workflowai.core.domain.run import Run from workflowai.core.domain.version_properties import VersionProperties @@ -539,3 +540,31 @@ async def test_list_models_registers_if_needed( assert models[0].modes == ["chat"] assert models[0].metadata is not None assert models[0].metadata.provider_name == "OpenAI" + + +class TestFetchCompletions: + async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): + """Test that fetch_completions correctly fetches and returns completions.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/runs/1/completions", + json=fixtures_json("completions.json"), + ) + + completions = await agent.fetch_completions("1") + assert completions == [ + Completion( + messages=[ + Message(role="system", content="I am instructions"), + Message(role="user", content="I am user message"), + ], + response="This is a test response", + usage=CompletionUsage( + completion_token_count=222, + completion_cost_usd=0.00013319999999999999, + prompt_token_count=1230, + prompt_cost_usd=0.00018449999999999999, + model_context_window_size=1048576, + ), + ), + ] diff --git a/workflowai/core/domain/completion.py b/workflowai/core/domain/completion.py new file mode 100644 index 0000000..ca20342 --- /dev/null +++ b/workflowai/core/domain/completion.py @@ -0,0 +1,39 @@ 
+from typing import Optional + +from pydantic import BaseModel, Field + + +class CompletionUsage(BaseModel): + """Usage information for a completion.""" + + completion_token_count: Optional[int] = None + completion_cost_usd: Optional[float] = None + reasoning_token_count: Optional[int] = None + prompt_token_count: Optional[int] = None + prompt_token_count_cached: Optional[int] = None + prompt_cost_usd: Optional[float] = None + prompt_audio_token_count: Optional[int] = None + prompt_audio_duration_seconds: Optional[float] = None + prompt_image_count: Optional[int] = None + model_context_window_size: Optional[int] = None + + +class Message(BaseModel): + """A message in a completion.""" + + role: str = "" + content: str = "" + + +class Completion(BaseModel): + """A completion from the model.""" + + messages: list[Message] = Field(default_factory=list) + response: Optional[str] = None + usage: CompletionUsage = Field(default_factory=CompletionUsage) + + +class CompletionsResponse(BaseModel): + """Response from the completions API endpoint.""" + + completions: list[Completion] diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 4493c7f..759fe4a 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -7,45 +7,13 @@ from workflowai import env from workflowai.core import _common_types from workflowai.core.client import _types -from workflowai.core.client._api import APIClient +from workflowai.core.domain.completion import Completion from workflowai.core.domain.errors import BaseError from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.tool_call import ToolCall, ToolCallRequest, ToolCallResult from workflowai.core.domain.version import Version -class CompletionUsage(BaseModel): - """Usage information for a completion.""" - completion_token_count: int - completion_cost_usd: float - reasoning_token_count: int - prompt_token_count: int - prompt_token_count_cached: int - prompt_cost_usd: float - 
prompt_audio_token_count: int - prompt_audio_duration_seconds: float - prompt_image_count: int - model_context_window_size: int - - -class Message(BaseModel): - """A message in a completion.""" - role: str - content: str - - -class Completion(BaseModel): - """A completion from the model.""" - messages: list[Message] - response: str - usage: CompletionUsage - - -class CompletionsResponse(BaseModel): - """Response from the completions API endpoint.""" - completions: list[Completion] - - class Run(BaseModel, Generic[AgentOutput]): """ A run is an instance of a agent with a specific input and output. @@ -163,7 +131,7 @@ def __str__(self) -> str: def run_url(self): return f"{env.WORKFLOWAI_APP_URL}/_/agents/{self.agent_id}/runs/{self.id}" - async def fetch_completions(self) -> CompletionsResponse: + async def fetch_completions(self) -> list[Completion]: """Fetch the completions for this run. Returns: @@ -178,17 +146,10 @@ async def fetch_completions(self) -> CompletionsResponse: if not self.id: raise ValueError("Run id is not set") - # The "_" refers to the currently authenticated tenant's namespace - return await self._agent.api.get( - f"/v1/_/agents/{self.agent_id}/runs/{self.id}/completions", - returns=CompletionsResponse, - ) + return await self._agent.fetch_completions(self.id) class _AgentBase(Protocol, Generic[AgentOutput]): - @property - def api(self) -> APIClient: ... - async def reply( self, run_id: str, @@ -198,3 +159,5 @@ async def reply( ) -> "Run[AgentOutput]": """Reply to a run. Either a user_message or tool_results must be provided.""" ... + + async def fetch_completions(self, run_id: str) -> list[Completion]: ... 
diff --git a/workflowai/core/domain/run_test.py b/workflowai/core/domain/run_test.py index 8202c17..ac0aae7 100644 --- a/workflowai/core/domain/run_test.py +++ b/workflowai/core/domain/run_test.py @@ -3,8 +3,11 @@ import pytest from pydantic import BaseModel -from workflowai.core.client._api import APIClient -from workflowai.core.domain.run import Completion, CompletionsResponse, CompletionUsage, Message, Run +from workflowai.core.domain.completion import Completion, CompletionUsage, Message +from workflowai.core.domain.run import ( + Run, + _AgentBase, # pyright: ignore [reportPrivateUsage] +) from workflowai.core.domain.version import Version from workflowai.core.domain.version_properties import VersionProperties @@ -14,8 +17,14 @@ class _TestOutput(BaseModel): @pytest.fixture -def run1() -> Run[_TestOutput]: - return Run[_TestOutput]( +def mock_agent() -> Mock: + mock = Mock(spec=_AgentBase) + return mock + + +@pytest.fixture +def run1(mock_agent: Mock) -> Run[_TestOutput]: + run = Run[_TestOutput]( id="run-id", agent_id="agent-id", schema_id=1, @@ -27,6 +36,8 @@ def run1() -> Run[_TestOutput]: tool_calls=[], tool_call_requests=[], ) + run._agent = mock_agent # pyright: ignore [reportPrivateUsage] + return run @pytest.fixture @@ -134,56 +145,40 @@ def test_run_url(self, run1: Run[_TestOutput]): class TestFetchCompletions: """Tests for the fetch_completions method of the Run class.""" - # Test the successful case of fetching completions: - # 1. Verifies that the API client is called with the correct URL and parameters - # 2. Verifies that the response is properly parsed into CompletionsResponse - # 3. Checks that all fields (messages, response, usage) are correctly populated - # 4. 
Ensures the completion contains the expected conversation history (system, user, assistant) - async def test_fetch_completions_success(self, run1: Run[_TestOutput]): - # Create a mock API client - mock_api = Mock(spec=APIClient) - mock_api.get.return_value = CompletionsResponse( - completions=[ - Completion( - messages=[ - Message(role="system", content="You are a helpful assistant"), - Message(role="user", content="Hello"), - Message(role="assistant", content="Hi there!"), - ], - response="Hi there!", - usage=CompletionUsage( - completion_token_count=3, - completion_cost_usd=0.001, - reasoning_token_count=10, - prompt_token_count=20, - prompt_token_count_cached=0, - prompt_cost_usd=0.002, - prompt_audio_token_count=0, - prompt_audio_duration_seconds=0, - prompt_image_count=0, - model_context_window_size=32000, - ), + # Test that the underlying agent is called with the proper run id + async def test_fetch_completions_success(self, run1: Run[_TestOutput], mock_agent: Mock): + mock_agent.fetch_completions.return_value = [ + Completion( + messages=[ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Hello"), + Message(role="assistant", content="Hi there!"), + ], + response="Hi there!", + usage=CompletionUsage( + completion_token_count=3, + completion_cost_usd=0.001, + reasoning_token_count=10, + prompt_token_count=20, + prompt_token_count_cached=0, + prompt_cost_usd=0.002, + prompt_audio_token_count=0, + prompt_audio_duration_seconds=0, + prompt_image_count=0, + model_context_window_size=32000, ), - ], - ) - - # Create a mock agent with the mock API client - mock_agent = Mock() - mock_agent.api = mock_api - run1._agent = mock_agent # pyright: ignore [reportPrivateUsage] + ), + ] # Call fetch_completions completions = await run1.fetch_completions() # Verify the API was called correctly - mock_api.get.assert_called_once_with( - "/v1/_/agents/agent-id/runs/run-id/completions", - returns=CompletionsResponse, - ) + 
mock_agent.fetch_completions.assert_called_once_with("run-id") # Verify the response - assert len(completions.completions) == 1 - completion = completions.completions[0] + assert len(completions) == 1 + completion = completions[0] assert len(completion.messages) == 3 assert completion.messages[0].role == "system" assert completion.messages[0].content == "You are a helpful assistant"