diff --git a/pyproject.toml b/pyproject.toml index f6b5f26..94def17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev17" +version = "0.6.0.dev18" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 3f060c1..8cf1bf1 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -202,6 +202,11 @@ class ListModelsResponse(Page[ModelInfo]): """Response from the list models API endpoint.""" +class ListModelsRequest(BaseModel): + instructions: Optional[str] = Field(default=None, description="Used to detect internal tools") + requires_tools: Optional[bool] = Field(default=None, description="Whether the agent uses external tools") + + class CompletionsResponse(BaseModel): """Response from the completions API endpoint.""" diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 5198d6e..dcfefb2 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -12,6 +12,7 @@ CompletionsResponse, CreateAgentRequest, CreateAgentResponse, + ListModelsRequest, ListModelsResponse, ModelInfo, ReplyRequest, @@ -477,7 +478,11 @@ def _sanitize_validator(cls, kwargs: RunParams[AgentOutput], default: OutputVali validator = kwargs.pop("validator", default) return validator, cast(BaseRunParams, kwargs) - async def list_models(self) -> list[ModelInfo]: + async def list_models( + self, + instructions: Optional[str] = None, + requires_tools: Optional[bool] = None, + ) -> list[ModelInfo]: """Fetch the list of available models from the API for this agent. Returns: @@ -486,12 +491,22 @@ async def list_models(self) -> list[ModelInfo]: Raises: ValueError: If the agent has not been registered (schema_id is None). 
""" + if not self.schema_id: self.schema_id = await self.register() - response = await self.api.get( + request_data = ListModelsRequest(instructions=instructions, requires_tools=requires_tools) + + if instructions is None and self.version and isinstance(self.version, VersionProperties): + request_data.instructions = self.version.instructions + + if requires_tools is None and self._tools: + request_data.requires_tools = True + + response = await self.api.post( # The "_" refers to the currently authenticated tenant's namespace f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/models", + data=request_data, returns=ListModelsResponse, ) return response.items diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index f62bf59..8b4f4ab 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -35,6 +35,49 @@ def agent(api_client: APIClient): return Agent(agent_id="123", schema_id=1, input_cls=HelloTaskInput, output_cls=HelloTaskOutput, api=api_client) +@pytest.fixture +def agent_with_instructions(api_client: APIClient): + return Agent( + agent_id="123", + schema_id=1, + input_cls=HelloTaskInput, + output_cls=HelloTaskOutput, + api=api_client, + version=VersionProperties(instructions="Some instructions"), + ) + + +@pytest.fixture +def agent_with_tools(api_client: APIClient): + def some_tool() -> str: + return "Hello, world!" + + return Agent( + agent_id="123", + schema_id=1, + input_cls=HelloTaskInput, + output_cls=HelloTaskOutput, + api=api_client, + tools=[some_tool], + ) + + +@pytest.fixture +def agent_with_tools_and_instructions(api_client: APIClient): + def some_tool() -> str: + return "Hello, world!" 
+ + return Agent( + agent_id="123", + schema_id=1, + input_cls=HelloTaskInput, + output_cls=HelloTaskOutput, + api=api_client, + version=VersionProperties(instructions="Some instructions"), + tools=[some_tool], + ) + + @pytest.fixture def agent_not_optional(api_client: APIClient): return Agent( @@ -463,7 +506,177 @@ async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_ # Verify the HTTP request was made correctly request = httpx_mock.get_request() assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + +@pytest.mark.asyncio +async def test_list_models_with_params_override(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + 
"release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent.list_models(instructions="Some override instructions", requires_tools=True) + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == { + "instructions": "Some override instructions", + "requires_tools": True, + } + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + +@pytest.mark.asyncio +async def test_list_models_with_params_override_and_agent_with_tools_and_instructions( + agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, +): + """Test that list_models 
correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent_with_tools_and_instructions.list_models( + instructions="Some override instructions", + requires_tools=False, + ) + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == { + "instructions": "Some override instructions", + "requires_tools": False, + } # Verify we get back the full ModelInfo objects assert len(models) == 2 @@ -530,7 +743,9 @@ async def test_list_models_registers_if_needed( reqs = httpx_mock.get_requests() assert len(reqs) == 2 assert reqs[0].url == 
"http://localhost:8000/v1/_/agents" + assert reqs[1].method == "POST" assert reqs[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/models" + assert json.loads(reqs[1].content) == {} # Verify we get back the full ModelInfo object assert len(models) == 1 @@ -542,6 +757,252 @@ async def test_list_models_registers_if_needed( assert models[0].metadata.provider_name == "OpenAI" +@pytest.mark.asyncio +async def test_list_models_with_instructions( + agent_with_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, +): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent_with_instructions.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not 
None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"instructions": "Some instructions"} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + +@pytest.mark.asyncio +async def test_list_models_with_tools( + agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, +): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 
0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent_with_tools.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"requires_tools": True} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + +@pytest.mark.asyncio +async def test_list_models_with_instructions_and_tools( + agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, +): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + 
"release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent_with_tools_and_instructions.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": True} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + class TestFetchCompletions: async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): """Test that fetch_completions correctly fetches and returns completions.""" diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 
759fe4a..05e32c9 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -150,6 +150,9 @@ async def fetch_completions(self) -> list[Completion]: class _AgentBase(Protocol, Generic[AgentOutput]): + # TODO: break the circular dependency properly — move this to a module-level `if TYPE_CHECKING:` import and quote the annotation below + from workflowai.core.client._models import ModelInfo + async def reply( self, run_id: str, @@ -161,3 +164,5 @@ async def reply( ... async def fetch_completions(self, run_id: str) -> list[Completion]: ... + + async def list_models(self) -> list[ModelInfo]: ...