diff --git a/tests/fixtures/completions.json b/tests/fixtures/completions.json new file mode 100644 index 0000000..7bba405 --- /dev/null +++ b/tests/fixtures/completions.json @@ -0,0 +1,24 @@ +{ + "completions": [ + { + "messages": [ + { + "role": "system", + "content": "I am instructions" + }, + { + "role": "user", + "content": "I am user message" + } + ], + "response": "This is a test response", + "usage": { + "completion_token_count": 222, + "completion_cost_usd": 0.00013319999999999999, + "prompt_token_count": 1230, + "prompt_cost_usd": 0.00018449999999999999, + "model_context_window_size": 1048576 + } + } + ] +} diff --git a/tests/fixtures/task_example.json b/tests/fixtures/task_example.json deleted file mode 100644 index 94718f2..0000000 --- a/tests/fixtures/task_example.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "id": "8f635b73-f403-47ee-bff9-18320616c6cc", - "task_id": "citytocapital", - "task_schema_id": 1, - "task_input": { - "name": "Houston" - }, - "task_input_hash": "403c3739ab1c20643336dde3ad2950bb", - "task_input_preview": "city: \"Houston\"", - "task_output": { - "message": "Austin" - }, - "task_output_hash": "41a00820c2d20f738de5cc6bcb02b550", - "task_output_preview": "capital: \"Austin\"", - "created_at": "2024-05-31T01:38:48.688000Z" -} diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 7dd55e7..3f060c1 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -5,6 +5,7 @@ from workflowai.core._common_types import OutputValidator from workflowai.core.domain.cache_usage import CacheUsage +from workflowai.core.domain.completion import Completion from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.tool_call import ToolCall as DToolCall @@ -160,6 +161,7 @@ class CreateAgentResponse(BaseModel): class ModelMetadata(BaseModel): """Metadata for a model.""" + provider_name: str = Field(description="Name of the model 
provider") price_per_input_token_usd: Optional[float] = Field(None, description="Cost per input token in USD") price_per_output_token_usd: Optional[float] = Field(None, description="Cost per output token in USD") @@ -170,6 +172,7 @@ class ModelMetadata(BaseModel): class ModelInfo(BaseModel): """Information about a model.""" + id: str = Field(description="Unique identifier for the model") name: str = Field(description="Display name of the model") icon_url: Optional[str] = Field(None, description="URL for the model's icon") @@ -187,11 +190,19 @@ class ModelInfo(BaseModel): T = TypeVar("T") + class Page(BaseModel, Generic[T]): """A generic paginated response.""" + items: list[T] = Field(description="List of items in this page") count: Optional[int] = Field(None, description="Total number of items available") class ListModelsResponse(Page[ModelInfo]): """Response from the list models API endpoint.""" + + +class CompletionsResponse(BaseModel): + """Response from the completions API endpoint.""" + + completions: list[Completion] diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 81004f0..5198d6e 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -9,6 +9,7 @@ from workflowai.core._common_types import BaseRunParams, OutputValidator, VersionRunParams from workflowai.core.client._api import APIClient from workflowai.core.client._models import ( + CompletionsResponse, CreateAgentRequest, CreateAgentResponse, ListModelsResponse, @@ -24,6 +25,7 @@ intolerant_validator, tolerant_validator, ) +from workflowai.core.domain.completion import Completion from workflowai.core.domain.errors import BaseError, WorkflowAIError from workflowai.core.domain.run import Run from workflowai.core.domain.task import AgentInput, AgentOutput @@ -493,3 +495,18 @@ async def list_models(self) -> list[ModelInfo]: returns=ListModelsResponse, ) return response.items + + async def fetch_completions(self, run_id: str) -> list[Completion]: 
+ """Fetch the completions for a run. + + Args: + run_id (str): The id of the run to fetch completions for. + + Returns: + CompletionsResponse: The completions for the run. + """ + raw = await self.api.get( + f"/v1/_/agents/{self.agent_id}/runs/{run_id}/completions", + returns=CompletionsResponse, + ) + return raw.completions diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 7168b43..f62bf59 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -19,6 +19,7 @@ from workflowai.core.client.client import ( WorkflowAI, ) +from workflowai.core.domain.completion import Completion, CompletionUsage, Message from workflowai.core.domain.errors import WorkflowAIError from workflowai.core.domain.run import Run from workflowai.core.domain.version_properties import VersionProperties @@ -539,3 +540,31 @@ async def test_list_models_registers_if_needed( assert models[0].modes == ["chat"] assert models[0].metadata is not None assert models[0].metadata.provider_name == "OpenAI" + + +class TestFetchCompletions: + async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): + """Test that fetch_completions correctly fetches and returns completions.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/runs/1/completions", + json=fixtures_json("completions.json"), + ) + + completions = await agent.fetch_completions("1") + assert completions == [ + Completion( + messages=[ + Message(role="system", content="I am instructions"), + Message(role="user", content="I am user message"), + ], + response="This is a test response", + usage=CompletionUsage( + completion_token_count=222, + completion_cost_usd=0.00013319999999999999, + prompt_token_count=1230, + prompt_cost_usd=0.00018449999999999999, + model_context_window_size=1048576, + ), + ), + ] diff --git 
a/workflowai/core/domain/completion.py b/workflowai/core/domain/completion.py new file mode 100644 index 0000000..ca20342 --- /dev/null +++ b/workflowai/core/domain/completion.py @@ -0,0 +1,39 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class CompletionUsage(BaseModel): + """Usage information for a completion.""" + + completion_token_count: Optional[int] = None + completion_cost_usd: Optional[float] = None + reasoning_token_count: Optional[int] = None + prompt_token_count: Optional[int] = None + prompt_token_count_cached: Optional[int] = None + prompt_cost_usd: Optional[float] = None + prompt_audio_token_count: Optional[int] = None + prompt_audio_duration_seconds: Optional[float] = None + prompt_image_count: Optional[int] = None + model_context_window_size: Optional[int] = None + + +class Message(BaseModel): + """A message in a completion.""" + + role: str = "" + content: str = "" + + +class Completion(BaseModel): + """A completion from the model.""" + + messages: list[Message] = Field(default_factory=list) + response: Optional[str] = None + usage: CompletionUsage = Field(default_factory=CompletionUsage) + + +class CompletionsResponse(BaseModel): + """Response from the completions API endpoint.""" + + completions: list[Completion] diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 686321c..759fe4a 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -7,6 +7,7 @@ from workflowai import env from workflowai.core import _common_types from workflowai.core.client import _types +from workflowai.core.domain.completion import Completion from workflowai.core.domain.errors import BaseError from workflowai.core.domain.task import AgentOutput from workflowai.core.domain.tool_call import ToolCall, ToolCallRequest, ToolCallResult @@ -130,6 +131,23 @@ def __str__(self) -> str: def run_url(self): return f"{env.WORKFLOWAI_APP_URL}/_/agents/{self.agent_id}/runs/{self.id}" + async def 
fetch_completions(self) -> list[Completion]: + """Fetch the completions for this run. + + Returns: + list[Completion]: The completions response containing a list of completions + with their messages, responses and usage information. + + Raises: + ValueError: If the agent is not set or if the run id is not set. + """ + if not self._agent: + raise ValueError("Agent is not set") + if not self.id: + raise ValueError("Run id is not set") + + return await self._agent.fetch_completions(self.id) + class _AgentBase(Protocol, Generic[AgentOutput]): async def reply( @@ -141,3 +159,5 @@ async def reply( ) -> "Run[AgentOutput]": """Reply to a run. Either a user_message or tool_results must be provided.""" ... + + async def fetch_completions(self, run_id: str) -> list[Completion]: ... diff --git a/workflowai/core/domain/run_test.py b/workflowai/core/domain/run_test.py index 886f5d5..ac0aae7 100644 --- a/workflowai/core/domain/run_test.py +++ b/workflowai/core/domain/run_test.py @@ -3,7 +3,11 @@ import pytest from pydantic import BaseModel -from workflowai.core.domain.run import Run +from workflowai.core.domain.completion import Completion, CompletionUsage, Message +from workflowai.core.domain.run import ( + Run, + _AgentBase, # pyright: ignore [reportPrivateUsage] +) from workflowai.core.domain.version import Version from workflowai.core.domain.version_properties import VersionProperties @@ -13,8 +17,14 @@ class _TestOutput(BaseModel): @pytest.fixture -def run1() -> Run[_TestOutput]: - return Run[_TestOutput]( +def mock_agent() -> Mock: + mock = Mock(spec=_AgentBase) + return mock + + +@pytest.fixture +def run1(mock_agent: Mock) -> Run[_TestOutput]: + run = Run[_TestOutput]( id="run-id", agent_id="agent-id", schema_id=1, @@ -26,6 +36,8 @@ def run1() -> Run[_TestOutput]: tool_calls=[], tool_call_requests=[], ) + run._agent = mock_agent # pyright: ignore [reportPrivateUsage] + return run @pytest.fixture @@ -128,3 +140,70 @@ class TestRunURL: 
@patch("workflowai.env.WORKFLOWAI_APP_URL", "https://workflowai.hello") def test_run_url(self, run1: Run[_TestOutput]): assert run1.run_url == "https://workflowai.hello/_/agents/agent-id/runs/run-id" + + +class TestFetchCompletions: + """Tests for the fetch_completions method of the Run class.""" + + # Test that the underlying agent is called with the proper run id + async def test_fetch_completions_success(self, run1: Run[_TestOutput], mock_agent: Mock): + mock_agent.fetch_completions.return_value = [ + Completion( + messages=[ + Message(role="system", content="You are a helpful assistant"), + Message(role="user", content="Hello"), + Message(role="assistant", content="Hi there!"), + ], + response="Hi there!", + usage=CompletionUsage( + completion_token_count=3, + completion_cost_usd=0.001, + reasoning_token_count=10, + prompt_token_count=20, + prompt_token_count_cached=0, + prompt_cost_usd=0.002, + prompt_audio_token_count=0, + prompt_audio_duration_seconds=0, + prompt_image_count=0, + model_context_window_size=32000, + ), + ), + ] + + # Call fetch_completions + completions = await run1.fetch_completions() + + # Verify the API was called correctly + mock_agent.fetch_completions.assert_called_once_with("run-id") + + # Verify the response + assert len(completions) == 1 + completion = completions[0] + assert len(completion.messages) == 3 + assert completion.messages[0].role == "system" + assert completion.messages[0].content == "You are a helpful assistant" + assert completion.response == "Hi there!" + assert completion.usage.completion_token_count == 3 + assert completion.usage.completion_cost_usd == 0.001 + + # Test that fetch_completions fails appropriately when the agent is not set: + # 1. This is a common error case that occurs when a Run object is created without an agent + # 2. The method should fail fast with a clear error message before attempting any API calls + # 3. 
This protects users from confusing errors that would occur if we tried to use the API client + async def test_fetch_completions_no_agent(self, run1: Run[_TestOutput]): + run1._agent = None # pyright: ignore [reportPrivateUsage] + with pytest.raises(ValueError, match="Agent is not set"): + await run1.fetch_completions() + + # Test that fetch_completions fails appropriately when the run ID is not set: + # 1. The run ID is required to construct the API endpoint URL + # 2. Without it, we can't make a valid API request + # 3. This validates that we fail fast with a clear error message + # 4. This should never happen in practice (as Run objects always have an ID), + # but we test it for completeness and to ensure robust error handling + async def test_fetch_completions_no_id(self, run1: Run[_TestOutput]): + mock_agent = Mock() + run1._agent = mock_agent # pyright: ignore [reportPrivateUsage] + run1.id = "" # Empty ID + with pytest.raises(ValueError, match="Run id is not set"): + await run1.fetch_completions()