From 606f66bb4c152b84c4d33c9763b684e699e82371 Mon Sep 17 00:00:00 2001 From: YBubu Date: Wed, 26 Feb 2025 10:51:13 +0000 Subject: [PATCH 1/3] pass instructions and requires_tools to /models --- pyproject.toml | 2 +- workflowai/core/client/_models.py | 5 + workflowai/core/client/agent.py | 13 +- workflowai/core/client/agent_test.py | 297 +++++++++++++++++++++++++++ 4 files changed, 315 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f6b5f26..94def17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.6.0.dev17" +version = "0.6.0.dev18" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 3f060c1..8ae8b76 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -202,6 +202,11 @@ class ListModelsResponse(Page[ModelInfo]): """Response from the list models API endpoint.""" +class ListModelsRequest(BaseModel): + instructions: Optional[str] = Field(default=None, description="Used to detect internal tools") + requires_tools: Optional[bool] = Field(default=False, description="Whether the agent uses external tools") + + class CompletionsResponse(BaseModel): """Response from the completions API endpoint.""" diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 5198d6e..de3279d 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -12,6 +12,7 @@ CompletionsResponse, CreateAgentRequest, CreateAgentResponse, + ListModelsRequest, ListModelsResponse, ModelInfo, ReplyRequest, @@ -486,12 +487,22 @@ async def list_models(self) -> list[ModelInfo]: Raises: ValueError: If the agent has not been registered (schema_id is None). """ + if not self.schema_id: self.schema_id = await self.register() - response = await self.api.get( + data = ListModelsRequest() + if self.version and isinstance(self.version, VersionProperties): + data.instructions = self.version.instructions + + if self._tools: + for _ in self._tools.values(): + data.requires_tools = True + + response = await self.api.post( # The "_" refers to the currently authenticated tenant's namespace f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/models", + data=data, returns=ListModelsResponse, ) return response.items diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index f62bf59..99dc244 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -35,6 +35,49 @@ def agent(api_client: APIClient): return Agent(agent_id="123", schema_id=1, input_cls=HelloTaskInput, output_cls=HelloTaskOutput, api=api_client) +@pytest.fixture +def agent_with_instructions(api_client: APIClient): + return Agent( + agent_id="123", + schema_id=1, + input_cls=HelloTaskInput, + output_cls=HelloTaskOutput, + api=api_client, + version=VersionProperties(instructions="Some instructions"), + ) + + +@pytest.fixture +def agent_with_tools(api_client: APIClient): + def some_tool() -> str: + return "Hello, world!" + + return Agent( + agent_id="123", + schema_id=1, + input_cls=HelloTaskInput, + output_cls=HelloTaskOutput, + api=api_client, + tools=[some_tool], + ) + + +@pytest.fixture +def agent_with_tools_and_instructions(api_client: APIClient): + def some_tool() -> str: + return "Hello, world!" + + return Agent( + agent_id="123", + schema_id=1, + input_cls=HelloTaskInput, + output_cls=HelloTaskOutput, + api=api_client, + version=VersionProperties(instructions="Some instructions"), + tools=[some_tool], + ) + + @pytest.fixture def agent_not_optional(api_client: APIClient): return Agent( @@ -463,7 +506,11 @@ async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_ # Verify the HTTP request was made correctly request = httpx_mock.get_request() assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == { + "requires_tools": False, + } # Verify we get back the full ModelInfo objects assert len(models) == 2 @@ -530,7 +577,11 @@ async def test_list_models_registers_if_needed( reqs = httpx_mock.get_requests() assert len(reqs) == 2 assert reqs[0].url == "http://localhost:8000/v1/_/agents" + assert reqs[1].method == "POST" assert reqs[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/models" + assert json.loads(reqs[1].content) == { + "requires_tools": False, + } # Verify we get back the full ModelInfo object assert len(models) == 1 @@ -542,6 +593,252 @@ async def test_list_models_registers_if_needed( assert models[0].metadata.provider_name == "OpenAI" +@pytest.mark.asyncio +async def test_list_models_with_instructions( + agent_with_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, +): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent_with_instructions.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": False} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + +@pytest.mark.asyncio +async def test_list_models_with_tools( + agent_with_tools: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, +): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent_with_tools.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"requires_tools": True} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + +@pytest.mark.asyncio +async def test_list_models_with_instructions_and_tools( + agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, +): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + method="POST", + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent_with_tools_and_instructions.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": True} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + class TestFetchCompletions: async def test_fetch_completions(self, agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): """Test that fetch_completions correctly fetches and returns completions.""" From 30232d8fa287353d606b2d820c9869efe04134f9 Mon Sep 17 00:00:00 2001 From: Yann BURY Date: Thu, 27 Feb 2025 13:18:24 +0400 Subject: [PATCH 2/3] Update workflowai/core/client/_models.py Co-authored-by: guillaq --- workflowai/core/client/_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 8ae8b76..8cf1bf1 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -204,7 +204,7 @@ class ListModelsResponse(Page[ModelInfo]): class ListModelsRequest(BaseModel): instructions: Optional[str] = Field(default=None, description="Used to detect internal tools") - requires_tools: Optional[bool] = Field(default=False, description="Whether the agent uses external tools") + requires_tools: Optional[bool] = Field(default=None, description="Whether the agent uses external tools") class CompletionsResponse(BaseModel): From fdbe09a041b055fac3fe119e1d8bd1412a5feda6 Mon Sep 17 00:00:00 2001 From: YBubu Date: Thu, 27 Feb 2025 09:49:45 +0000 Subject: [PATCH 3/3] add override params to list_models --- workflowai/core/client/agent.py | 20 ++-- workflowai/core/client/agent_test.py | 172 ++++++++++++++++++++++++++- workflowai/core/domain/run.py | 5 + 3 files changed, 185 insertions(+), 12 deletions(-) diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index de3279d..dcfefb2 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -478,7 +478,11 @@ def _sanitize_validator(cls, kwargs: RunParams[AgentOutput], default: OutputVali validator = kwargs.pop("validator", default) return validator, cast(BaseRunParams, kwargs) - async def list_models(self) -> list[ModelInfo]: + async def list_models( + self, + instructions: Optional[str] = None, + requires_tools: Optional[bool] = None, + ) -> list[ModelInfo]: """Fetch the list of available models from the API for this agent. Returns: @@ -491,18 +495,18 @@ async def list_models(self) -> list[ModelInfo]: if not self.schema_id: self.schema_id = await self.register() - data = ListModelsRequest() - if self.version and isinstance(self.version, VersionProperties): - data.instructions = self.version.instructions + request_data = ListModelsRequest(instructions=instructions, requires_tools=requires_tools) - if self._tools: - for _ in self._tools.values(): - data.requires_tools = True + if instructions is None and self.version and isinstance(self.version, VersionProperties): + request_data.instructions = self.version.instructions + + if requires_tools is None and self._tools: + request_data.requires_tools = True response = await self.api.post( # The "_" refers to the currently authenticated tenant's namespace f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/models", - data=data, + data=request_data, returns=ListModelsResponse, ) return response.items diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 99dc244..8b4f4ab 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -503,12 +503,178 @@ async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_ # Call the method models = await agent.list_models() + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == {} + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + +@pytest.mark.asyncio +async def test_list_models_with_params_override(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent.list_models(instructions="Some override instructions", requires_tools=True) + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.method == "POST" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + assert json.loads(request.content) == { + "instructions": "Some override instructions", + "requires_tools": True, + } + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + +@pytest.mark.asyncio +async def test_list_models_with_params_override_and_agent_with_tools_and_instructions( + agent_with_tools_and_instructions: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, +): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent_with_tools_and_instructions.list_models( + instructions="Some override instructions", + requires_tools=False, + ) + # Verify the HTTP request was made correctly request = httpx_mock.get_request() assert request is not None, "Expected an HTTP request to be made" assert request.method == "POST" assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" assert json.loads(request.content) == { + "instructions": "Some override instructions", "requires_tools": False, } @@ -579,9 +745,7 @@ async def test_list_models_registers_if_needed( assert reqs[0].url == "http://localhost:8000/v1/_/agents" assert reqs[1].method == "POST" assert reqs[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/models" - assert json.loads(reqs[1].content) == { - "requires_tools": False, - } + assert json.loads(reqs[1].content) == {} # Verify we get back the full ModelInfo object assert len(models) == 1 @@ -656,7 +820,7 @@ async def test_list_models_with_instructions( assert request is not None, "Expected an HTTP request to be made" assert request.method == "POST" assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" - assert json.loads(request.content) == {"instructions": "Some instructions", "requires_tools": False} + assert json.loads(request.content) == {"instructions": "Some instructions"} # Verify we get back the full ModelInfo objects assert len(models) == 2 diff --git a/workflowai/core/domain/run.py b/workflowai/core/domain/run.py index 759fe4a..05e32c9 100644 --- a/workflowai/core/domain/run.py +++ b/workflowai/core/domain/run.py @@ -150,6 +150,9 @@ async def fetch_completions(self) -> list[Completion]: class _AgentBase(Protocol, Generic[AgentOutput]): + # TODO: fix circular dep + from workflowai.core.client._models import ModelInfo + async def reply( self, run_id: str, @@ -161,3 +164,5 @@ async def reply( ... async def fetch_completions(self, run_id: str) -> list[Completion]: ... + + async def list_models(self) -> list[ModelInfo]: ...