diff --git a/tests/e2e/assets/call.mp3 b/tests/e2e/assets/call.mp3 new file mode 100644 index 0000000..2093286 Binary files /dev/null and b/tests/e2e/assets/call.mp3 differ diff --git a/tests/e2e/audio_models_test.py b/tests/e2e/audio_models_test.py new file mode 100644 index 0000000..64a608c --- /dev/null +++ b/tests/e2e/audio_models_test.py @@ -0,0 +1,111 @@ +""" +This test verifies model availability for audio processing tasks. +It checks which models support audio processing and which don't, +ensuring proper handling of unsupported models. +""" + +import base64 +import os + +import pytest +from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType] + +import workflowai +from workflowai import Model, Run +from workflowai.fields import Audio + + +class AudioInput(BaseModel): + """Input containing the audio file to analyze.""" + audio: Audio = Field( + description="The audio recording to analyze for spam/robocall detection", + ) + + +class SpamIndicator(BaseModel): + """A specific indicator that suggests the call might be spam.""" + description: str = Field( + description="Description of the spam indicator found in the audio", + examples=[ + "Uses urgency to pressure the listener", + "Mentions winning a prize without entering a contest", + "Automated/robotic voice detected", + ], + ) + quote: str = Field( + description="The exact quote or timestamp where this indicator appears", + examples=[ + "'You must act now before it's too late'", + "'You've been selected as our prize winner'", + "0:05-0:15 - Synthetic voice pattern detected", + ], + ) + + +class AudioClassification(BaseModel): + """Output containing the spam classification results.""" + is_spam: bool = Field( + description="Whether the audio is classified as spam/robocall", + ) + confidence_score: float = Field( + description="Confidence score for the classification (0.0 to 1.0)", + ge=0.0, + le=1.0, + ) + spam_indicators: list[SpamIndicator] = Field( + default_factory=list, + 
description="List of specific indicators that suggest this is spam", + ) + reasoning: str = Field( + description="Detailed explanation of why this was classified as spam or legitimate", + ) + + +@workflowai.agent( + id="audio-spam-detector", + model=Model.GEMINI_1_5_FLASH_LATEST, +) +async def classify_audio(audio_input: AudioInput) -> Run[AudioClassification]: + """ + Analyze the audio recording to determine if it's a spam/robocall. + + Guidelines: + 1. Listen for common spam/robocall indicators: + - Use of urgency or pressure tactics + - Unsolicited offers or prizes + - Automated/synthetic voices + - Requests for personal/financial information + - Impersonation of legitimate organizations + + 2. Consider both content and delivery: + - What is being said (transcribe key parts) + - How it's being said (tone, pacing, naturalness) + - Background noise and call quality + + 3. Provide clear reasoning: + - Cite specific examples from the audio + - Explain confidence level + - Note any uncertainty + """ + ... + + +@pytest.fixture +def audio_file() -> Audio: + """Load the test audio file.""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + audio_path = os.path.join(current_dir, "assets", "call.mp3") + + if not os.path.exists(audio_path): + raise FileNotFoundError( + f"Audio file not found at {audio_path}. 
" + "Please make sure you have the example audio file in the correct location.", + ) + + with open(audio_path, "rb") as f: + audio_data = f.read() + + return Audio( + content_type="audio/mp3", + data=base64.b64encode(audio_data).decode(), + ) diff --git a/workflowai/core/client/_models.py b/workflowai/core/client/_models.py index 2f4e429..7dd55e7 100644 --- a/workflowai/core/client/_models.py +++ b/workflowai/core/client/_models.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Optional, Union +from typing import Any, Generic, Literal, Optional, TypeVar, Union from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType] from typing_extensions import NotRequired, TypedDict @@ -156,3 +156,42 @@ class CreateAgentRequest(BaseModel): class CreateAgentResponse(BaseModel): id: str schema_id: int + + +class ModelMetadata(BaseModel): + """Metadata for a model.""" + provider_name: str = Field(description="Name of the model provider") + price_per_input_token_usd: Optional[float] = Field(None, description="Cost per input token in USD") + price_per_output_token_usd: Optional[float] = Field(None, description="Cost per output token in USD") + release_date: Optional[str] = Field(None, description="Release date of the model") + context_window_tokens: Optional[int] = Field(None, description="Size of the context window in tokens") + quality_index: Optional[float] = Field(None, description="Quality index of the model") + + +class ModelInfo(BaseModel): + """Information about a model.""" + id: str = Field(description="Unique identifier for the model") + name: str = Field(description="Display name of the model") + icon_url: Optional[str] = Field(None, description="URL for the model's icon") + modes: list[str] = Field(default_factory=list, description="Supported modes for this model") + is_not_supported_reason: Optional[str] = Field( + None, + description="Reason why the model is not supported, if applicable", + ) + average_cost_per_run_usd: Optional[float] = 
Field(None, description="Average cost per run in USD") + is_latest: bool = Field(default=False, description="Whether this is the latest version of the model") + metadata: Optional[ModelMetadata] = Field(None, description="Additional metadata about the model") + is_default: bool = Field(default=False, description="Whether this is the default model") + providers: list[str] = Field(default_factory=list, description="List of providers that offer this model") + + +T = TypeVar("T") + +class Page(BaseModel, Generic[T]): + """A generic paginated response.""" + items: list[T] = Field(description="List of items in this page") + count: Optional[int] = Field(None, description="Total number of items available") + + +class ListModelsResponse(Page[ModelInfo]): + """Response from the list models API endpoint.""" diff --git a/workflowai/core/client/agent.py b/workflowai/core/client/agent.py index 837139e..2cf9e9c 100644 --- a/workflowai/core/client/agent.py +++ b/workflowai/core/client/agent.py @@ -11,6 +11,8 @@ from workflowai.core.client._models import ( CreateAgentRequest, CreateAgentResponse, + ListModelsResponse, + ModelInfo, ReplyRequest, RunRequest, RunResponse, @@ -469,3 +471,22 @@ def _sanitize_validator(cls, kwargs: RunParams[AgentOutput], default: OutputValidator[AgentOutput]): validator = kwargs.pop("validator", default) return validator, cast(BaseRunParams, kwargs) + + async def list_models(self) -> list[ModelInfo]: + """Fetch the list of available models from the API for this agent. + + Returns: + list[ModelInfo]: List of available models with their full information. + + Note: + If the agent has not been registered yet (schema_id is None), it is registered automatically first. 
+ """ + if not self.schema_id: + self.schema_id = await self.register() + + response = await self.api.get( + # The "_" refers to the currently authenticated tenant's namespace + f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/models", + returns=ListModelsResponse, + ) + return response.items diff --git a/workflowai/core/client/agent_test.py b/workflowai/core/client/agent_test.py index 559f41d..2614305 100644 --- a/workflowai/core/client/agent_test.py +++ b/workflowai/core/client/agent_test.py @@ -13,6 +13,7 @@ ) from tests.utils import fixtures_json from workflowai.core.client._api import APIClient +from workflowai.core.client._models import ModelInfo from workflowai.core.client.agent import Agent from workflowai.core.client.client import ( WorkflowAI, @@ -367,3 +368,139 @@ def test_version_properties_with_model(self, agent: Agent[HelloTaskInput, HelloT def test_version_with_models_and_version(self, agent: Agent[HelloTaskInput, HelloTaskOutput]): # If version is explcitly provided then it takes priority and we log a warning assert agent._sanitize_version({"version": "staging", "model": "gemini-1.5-pro-latest"}) == "staging" # pyright: ignore [reportPrivateUsage] + + +@pytest.mark.asyncio +async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock): + """Test that list_models correctly fetches and returns available models.""" + # Mock the HTTP response instead of the API client method + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/1/models", + json={ + "items": [ + { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + 
"is_default": True, + "providers": ["openai"], + }, + { + "id": "claude-3", + "name": "Claude 3", + "icon_url": "https://example.com/claude3.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.015, + "is_latest": True, + "metadata": { + "provider_name": "Anthropic", + "price_per_input_token_usd": 0.00015, + "price_per_output_token_usd": 0.00025, + "release_date": "2024-03-01", + "context_window_tokens": 200000, + "quality_index": 0.98, + }, + "is_default": False, + "providers": ["anthropic"], + }, + ], + "count": 2, + }, + ) + + # Call the method + models = await agent.list_models() + + # Verify the HTTP request was made correctly + request = httpx_mock.get_request() + assert request is not None, "Expected an HTTP request to be made" + assert request.url == "http://localhost:8000/v1/_/agents/123/schemas/1/models" + + # Verify we get back the full ModelInfo objects + assert len(models) == 2 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" + + assert isinstance(models[1], ModelInfo) + assert models[1].id == "claude-3" + assert models[1].name == "Claude 3" + assert models[1].modes == ["chat"] + assert models[1].metadata is not None + assert models[1].metadata.provider_name == "Anthropic" + + +@pytest.mark.asyncio +async def test_list_models_registers_if_needed( + agent_no_schema: Agent[HelloTaskInput, HelloTaskOutput], + httpx_mock: HTTPXMock, +): + """Test that list_models registers the agent if it hasn't been registered yet.""" + # Mock the registration response + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents", + json={"id": "123", "schema_id": 2}, + ) + + # Mock the models response with the new structure + httpx_mock.add_response( + url="http://localhost:8000/v1/_/agents/123/schemas/2/models", + json={ + "items": [ 
+ { + "id": "gpt-4", + "name": "GPT-4", + "icon_url": "https://example.com/gpt4.png", + "modes": ["chat"], + "is_not_supported_reason": None, + "average_cost_per_run_usd": 0.01, + "is_latest": True, + "metadata": { + "provider_name": "OpenAI", + "price_per_input_token_usd": 0.0001, + "price_per_output_token_usd": 0.0002, + "release_date": "2024-01-01", + "context_window_tokens": 128000, + "quality_index": 0.95, + }, + "is_default": True, + "providers": ["openai"], + }, + ], + "count": 1, + }, + ) + + # Call the method + models = await agent_no_schema.list_models() + + # Verify both API calls were made + reqs = httpx_mock.get_requests() + assert len(reqs) == 2 + assert reqs[0].url == "http://localhost:8000/v1/_/agents" + assert reqs[1].url == "http://localhost:8000/v1/_/agents/123/schemas/2/models" + + # Verify we get back the full ModelInfo object + assert len(models) == 1 + assert isinstance(models[0], ModelInfo) + assert models[0].id == "gpt-4" + assert models[0].name == "GPT-4" + assert models[0].modes == ["chat"] + assert models[0].metadata is not None + assert models[0].metadata.provider_name == "OpenAI" diff --git a/workflowai/core/fields/audio.py b/workflowai/core/fields/audio.py new file mode 100644 index 0000000..394fbed --- /dev/null +++ b/workflowai/core/fields/audio.py @@ -0,0 +1,12 @@ +"""Audio field for handling audio file inputs.""" + +from workflowai.core.fields.file import File + + +class Audio(File): + """A field representing an audio file. + + This field is used to handle audio inputs in various formats (MP3, WAV, etc.). + The audio can be provided either as base64-encoded data or as a URL. 
+ """ + pass diff --git a/workflowai/fields.py b/workflowai/fields.py index ddbb632..df32c14 100644 --- a/workflowai/fields.py +++ b/workflowai/fields.py @@ -1,3 +1,4 @@ +from workflowai.core.fields.audio import Audio as Audio from workflowai.core.fields.chat_message import ChatMessage as ChatMessage from workflowai.core.fields.email_address import EmailAddressStr as EmailAddressStr from workflowai.core.fields.file import File as File