Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added tests/e2e/assets/call.mp3
Binary file not shown.
111 changes: 111 additions & 0 deletions tests/e2e/audio_models_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
This test verifies model availability for audio processing tasks.
It checks which models support audio processing and which don't,
ensuring proper handling of unsupported models.
"""

import base64
import os
from pathlib import Path

import pytest
from pydantic import BaseModel, Field  # pyright: ignore [reportUnknownVariableType]

import workflowai
from workflowai import Model, Run
from workflowai.fields import Audio


class AudioInput(BaseModel):
    """Input containing the audio file to analyze."""

    # Single audio attachment; per workflowai.fields.Audio this can carry either
    # base64-encoded bytes or a URL.
    audio: Audio = Field(
        description="The audio recording to analyze for spam/robocall detection",
    )


class SpamIndicator(BaseModel):
    """A specific indicator that suggests the call might be spam."""

    # Human-readable summary of the suspicious pattern the model found.
    description: str = Field(
        description="Description of the spam indicator found in the audio",
        examples=[
            "Uses urgency to pressure the listener",
            "Mentions winning a prize without entering a contest",
            "Automated/robotic voice detected",
        ],
    )
    # Supporting evidence: a verbatim quote or a timestamp range from the recording.
    quote: str = Field(
        description="The exact quote or timestamp where this indicator appears",
        examples=[
            "'You must act now before it's too late'",
            "'You've been selected as our prize winner'",
            "0:05-0:15 - Synthetic voice pattern detected",
        ],
    )


class AudioClassification(BaseModel):
    """Output containing the spam classification results."""

    # Final binary verdict for the recording.
    is_spam: bool = Field(
        description="Whether the audio is classified as spam/robocall",
    )
    # Bounded to [0.0, 1.0] by the pydantic ge/le validators below.
    confidence_score: float = Field(
        description="Confidence score for the classification (0.0 to 1.0)",
        ge=0.0,
        le=1.0,
    )
    # Defaults to an empty list, i.e. no indicators detected.
    spam_indicators: list[SpamIndicator] = Field(
        default_factory=list,
        description="List of specific indicators that suggest this is spam",
    )
    reasoning: str = Field(
        description="Detailed explanation of why this was classified as spam or legitimate",
    )


# NOTE(review): the docstring below is presumably sent to the model as the
# agent's instructions (the function body is intentionally `...`), so it is
# part of runtime behavior — confirm against the workflowai SDK docs before
# editing it for style.
@workflowai.agent(
    id="audio-spam-detector",
    model=Model.GEMINI_1_5_FLASH_LATEST,
)
async def classify_audio(audio_input: AudioInput) -> Run[AudioClassification]:
    """
    Analyze the audio recording to determine if it's a spam/robocall.

    Guidelines:
    1. Listen for common spam/robocall indicators:
       - Use of urgency or pressure tactics
       - Unsolicited offers or prizes
       - Automated/synthetic voices
       - Requests for personal/financial information
       - Impersonation of legitimate organizations

    2. Consider both content and delivery:
       - What is being said (transcribe key parts)
       - How it's being said (tone, pacing, naturalness)
       - Background noise and call quality

    3. Provide clear reasoning:
       - Cite specific examples from the audio
       - Explain confidence level
       - Note any uncertainty
    """
    ...


@pytest.fixture
def audio_file() -> Audio:
    """Load the bundled test recording (assets/call.mp3) as a base64 Audio field.

    The path is resolved relative to this test module so the fixture works
    regardless of the working directory pytest was launched from.

    Raises:
        FileNotFoundError: if the asset is missing, with a hint about where
            the file is expected to live.
    """
    audio_path = Path(__file__).resolve().parent / "assets" / "call.mp3"

    if not audio_path.exists():
        raise FileNotFoundError(
            f"Audio file not found at {audio_path}. "
            "Please make sure you have the example audio file in the correct location.",
        )

    audio_data = audio_path.read_bytes()

    return Audio(
        # NOTE(review): the IANA-registered MIME type for MP3 is "audio/mpeg";
        # "audio/mp3" is kept as-is because the backend may expect it — confirm.
        content_type="audio/mp3",
        data=base64.b64encode(audio_data).decode(),
    )
41 changes: 40 additions & 1 deletion workflowai/core/client/_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Literal, Optional, Union
from typing import Any, Generic, Literal, Optional, TypeVar, Union

from pydantic import BaseModel, Field # pyright: ignore [reportUnknownVariableType]
from typing_extensions import NotRequired, TypedDict
Expand Down Expand Up @@ -156,3 +156,42 @@ class CreateAgentRequest(BaseModel):
class CreateAgentResponse(BaseModel):
id: str
schema_id: int


class ModelMetadata(BaseModel):
    """Metadata for a model.

    Every field except ``provider_name`` is optional: the API may omit any of
    them for a given model, in which case the value stays ``None``.
    """

    provider_name: str = Field(description="Name of the model provider")
    price_per_input_token_usd: Optional[float] = Field(None, description="Cost per input token in USD")
    price_per_output_token_usd: Optional[float] = Field(None, description="Cost per output token in USD")
    # NOTE(review): sibling tests use "YYYY-MM-DD" strings here — presumably ISO
    # dates; confirm the API contract before parsing.
    release_date: Optional[str] = Field(None, description="Release date of the model")
    context_window_tokens: Optional[int] = Field(None, description="Size of the context window in tokens")
    quality_index: Optional[float] = Field(None, description="Quality index of the model")


class ModelInfo(BaseModel):
    """Information about a model."""

    id: str = Field(description="Unique identifier for the model")
    name: str = Field(description="Display name of the model")
    icon_url: Optional[str] = Field(None, description="URL for the model's icon")
    # e.g. ["chat"] in the sibling tests; empty list when the API reports none.
    modes: list[str] = Field(default_factory=list, description="Supported modes for this model")
    # None means the model IS supported; a string explains why it is not.
    is_not_supported_reason: Optional[str] = Field(
        None,
        description="Reason why the model is not supported, if applicable",
    )
    average_cost_per_run_usd: Optional[float] = Field(None, description="Average cost per run in USD")
    is_latest: bool = Field(default=False, description="Whether this is the latest version of the model")
    # Optional nested pricing/provider details; None when the API omits it.
    metadata: Optional[ModelMetadata] = Field(None, description="Additional metadata about the model")
    is_default: bool = Field(default=False, description="Whether this is the default model")
    providers: list[str] = Field(default_factory=list, description="List of providers that offer this model")


# Item type parameter for the generic Page container below.
T = TypeVar("T")


class Page(BaseModel, Generic[T]):
    """A generic paginated response.

    Endpoint-specific responses subclass this with a concrete item type
    (e.g. ``Page[ModelInfo]``) so deserialization yields typed items.
    """

    items: list[T] = Field(description="List of items in this page")
    # None when the API does not report an overall total.
    count: Optional[int] = Field(None, description="Total number of items available")


# Concrete page type returned by GET .../schemas/{schema_id}/models; inherits
# `items` / `count` from Page with ModelInfo as the item type.
class ListModelsResponse(Page[ModelInfo]):
    """Response from the list models API endpoint."""
21 changes: 21 additions & 0 deletions workflowai/core/client/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from workflowai.core.client._models import (
CreateAgentRequest,
CreateAgentResponse,
ListModelsResponse,
ModelInfo,
ReplyRequest,
RunRequest,
RunResponse,
Expand Down Expand Up @@ -469,3 +471,22 @@ async def reply(
def _sanitize_validator(cls, kwargs: RunParams[AgentOutput], default: OutputValidator[AgentOutput]):
validator = kwargs.pop("validator", default)
return validator, cast(BaseRunParams, kwargs)

async def list_models(self) -> list[ModelInfo]:
"""Fetch the list of available models from the API for this agent.

Returns:
list[ModelInfo]: List of available models with their full information.

Raises:
ValueError: If the agent has not been registered (schema_id is None).
"""
if not self.schema_id:
self.schema_id = await self.register()

response = await self.api.get(
# The "_" refers to the currently authenticated tenant's namespace
f"/v1/_/agents/{self.agent_id}/schemas/{self.schema_id}/models",
returns=ListModelsResponse,
)
return response.items
137 changes: 137 additions & 0 deletions workflowai/core/client/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from tests.utils import fixtures_json
from workflowai.core.client._api import APIClient
from workflowai.core.client._models import ModelInfo
from workflowai.core.client.agent import Agent
from workflowai.core.client.client import (
WorkflowAI,
Expand Down Expand Up @@ -367,3 +368,139 @@ def test_version_properties_with_model(self, agent: Agent[HelloTaskInput, HelloT
    def test_version_with_models_and_version(self, agent: Agent[HelloTaskInput, HelloTaskOutput]):
        # If version is explicitly provided then it takes priority and we log a warning
        assert agent._sanitize_version({"version": "staging", "model": "gemini-1.5-pro-latest"}) == "staging"  # pyright: ignore [reportPrivateUsage]


@pytest.mark.asyncio
async def test_list_models(agent: Agent[HelloTaskInput, HelloTaskOutput], httpx_mock: HTTPXMock):
    """Test that list_models correctly fetches and returns available models."""
    models_url = "http://localhost:8000/v1/_/agents/123/schemas/1/models"

    # Stub the endpoint at the HTTP layer rather than patching the API client.
    httpx_mock.add_response(
        url=models_url,
        json={
            "items": [
                {
                    "id": "gpt-4",
                    "name": "GPT-4",
                    "icon_url": "https://example.com/gpt4.png",
                    "modes": ["chat"],
                    "is_not_supported_reason": None,
                    "average_cost_per_run_usd": 0.01,
                    "is_latest": True,
                    "metadata": {
                        "provider_name": "OpenAI",
                        "price_per_input_token_usd": 0.0001,
                        "price_per_output_token_usd": 0.0002,
                        "release_date": "2024-01-01",
                        "context_window_tokens": 128000,
                        "quality_index": 0.95,
                    },
                    "is_default": True,
                    "providers": ["openai"],
                },
                {
                    "id": "claude-3",
                    "name": "Claude 3",
                    "icon_url": "https://example.com/claude3.png",
                    "modes": ["chat"],
                    "is_not_supported_reason": None,
                    "average_cost_per_run_usd": 0.015,
                    "is_latest": True,
                    "metadata": {
                        "provider_name": "Anthropic",
                        "price_per_input_token_usd": 0.00015,
                        "price_per_output_token_usd": 0.00025,
                        "release_date": "2024-03-01",
                        "context_window_tokens": 200000,
                        "quality_index": 0.98,
                    },
                    "is_default": False,
                    "providers": ["anthropic"],
                },
            ],
            "count": 2,
        },
    )

    models = await agent.list_models()

    # The request should target exactly the URL stubbed above.
    request = httpx_mock.get_request()
    assert request is not None, "Expected an HTTP request to be made"
    assert request.url == models_url

    # Full ModelInfo objects come back, in payload order.
    expected = [
        ("gpt-4", "GPT-4", "OpenAI"),
        ("claude-3", "Claude 3", "Anthropic"),
    ]
    assert len(models) == len(expected)
    for model, (model_id, display_name, provider_name) in zip(models, expected):
        assert isinstance(model, ModelInfo)
        assert model.id == model_id
        assert model.name == display_name
        assert model.modes == ["chat"]
        assert model.metadata is not None
        assert model.metadata.provider_name == provider_name


@pytest.mark.asyncio
async def test_list_models_registers_if_needed(
    agent_no_schema: Agent[HelloTaskInput, HelloTaskOutput],
    httpx_mock: HTTPXMock,
):
    """Test that list_models registers the agent if it hasn't been registered yet."""
    register_url = "http://localhost:8000/v1/_/agents"
    models_url = "http://localhost:8000/v1/_/agents/123/schemas/2/models"

    # First call: registration assigns agent id "123" and schema 2.
    httpx_mock.add_response(
        url=register_url,
        json={"id": "123", "schema_id": 2},
    )

    # Second call: models listing for the freshly registered schema.
    httpx_mock.add_response(
        url=models_url,
        json={
            "items": [
                {
                    "id": "gpt-4",
                    "name": "GPT-4",
                    "icon_url": "https://example.com/gpt4.png",
                    "modes": ["chat"],
                    "is_not_supported_reason": None,
                    "average_cost_per_run_usd": 0.01,
                    "is_latest": True,
                    "metadata": {
                        "provider_name": "OpenAI",
                        "price_per_input_token_usd": 0.0001,
                        "price_per_output_token_usd": 0.0002,
                        "release_date": "2024-01-01",
                        "context_window_tokens": 128000,
                        "quality_index": 0.95,
                    },
                    "is_default": True,
                    "providers": ["openai"],
                },
            ],
            "count": 1,
        },
    )

    models = await agent_no_schema.list_models()

    # Registration must happen first, then the models fetch.
    reqs = httpx_mock.get_requests()
    assert len(reqs) == 2
    assert reqs[0].url == register_url
    assert reqs[1].url == models_url

    # The single payload item deserializes into a full ModelInfo.
    assert len(models) == 1
    model = models[0]
    assert isinstance(model, ModelInfo)
    assert model.id == "gpt-4"
    assert model.name == "GPT-4"
    assert model.modes == ["chat"]
    assert model.metadata is not None
    assert model.metadata.provider_name == "OpenAI"
12 changes: 12 additions & 0 deletions workflowai/core/fields/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Audio field for handling audio file inputs."""

from workflowai.core.fields.file import File


class Audio(File):
    """A field representing an audio file.

    This field is used to handle audio inputs in various formats (MP3, WAV, etc.).
    The audio can be provided either as base64-encoded data or as a URL.
    """
    # No extra behavior: the class exists as a semantically-named alias of File.
    # (The redundant `pass` was removed — the docstring already forms the body.)
1 change: 1 addition & 0 deletions workflowai/fields.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from workflowai.core.fields.audio import Audio as Audio
from workflowai.core.fields.chat_message import ChatMessage as ChatMessage
from workflowai.core.fields.email_address import EmailAddressStr as EmailAddressStr
from workflowai.core.fields.file import File as File
Expand Down