Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.6.0.dev17"
version = "0.6.0.dev18"
description = ""
authors = ["Guillaume Aquilina <guillaume@workflowai.com>"]
readme = "README.md"
Expand Down
5 changes: 5 additions & 0 deletions workflowai/core/client/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ class ListModelsResponse(Page[ModelInfo]):
"""Response from the list models API endpoint."""


class ListModelsRequest(BaseModel):
instructions: Optional[str] = Field(default=None, description="Used to detect internal tools")
requires_tools: Optional[bool] = Field(default=None, description="Whether the agent uses external tools")


class CompletionsResponse(BaseModel):
"""Response from the completions API endpoint."""

Expand Down
19 changes: 17 additions & 2 deletions workflowai/core/client/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
CompletionsResponse,
CreateAgentRequest,
CreateAgentResponse,
ListModelsRequest,
ListModelsResponse,
ModelInfo,
ReplyRequest,
Expand Down Expand Up @@ -477,7 +478,11 @@ def _sanitize_validator(cls, kwargs: RunParams[AgentOutput], default: OutputVali
validator = kwargs.pop("validator", default)
return validator, cast(BaseRunParams, kwargs)

async def list_models(self) -> list[ModelInfo]:
async def list_models(
self,
instructions: Optional[str] = None,
requires_tools: Optional[bool] = None,
) -> list[ModelInfo]:
"""Fetch the list of available models from the API for this agent.

Returns:
Expand All @@ -486,12 +491,22 @@ async def list_models(self) -> list[ModelInfo]:
Raises:
ValueError: If the agent has not been registered (schema_id is None).
"""

if not self.schema_id:
self.schema_id = await self.register()

response = await self.api.get(
request_data = ListModelsRequest(instructions=instructions, requires_tools=requires_tools)

if instructions is None and self.version and isinstance(self.version, VersionProperties):
request_data.instructions = self.version.instructions

if requires_tools is None and self._tools:
request_data.requires_tools = True

response = await self.api.post(
# The "_" refers to the currently authenticated tenant's namespace
f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/models",
data=request_data,
returns=ListModelsResponse,
)
return response.items
Expand Down
Loading