From 38123a5ec0ffbb0b021d5c48955f8bff31b31987 Mon Sep 17 00:00:00 2001 From: Sameer Pashikanti Date: Tue, 16 Sep 2025 12:50:14 +0000 Subject: [PATCH] updates - add assembly ai provider - actually use the ENABLE_RUNPOD_DIARIZATION variable --- echo/server/dembrane/config.py | 83 ++++---- echo/server/dembrane/conversation_health.py | 61 +++--- echo/server/dembrane/s3.py | 3 + echo/server/dembrane/service/__init__.py | 33 +++- echo/server/dembrane/service/conversation.py | 11 +- echo/server/dembrane/service/event.py | 34 ---- echo/server/dembrane/transcribe.py | 185 ++++++++++++------ echo/server/tests/test_transcribe_assembly.py | 151 ++++++++++++++ ...ranscribe.py => test_transcribe_runpod.py} | 0 9 files changed, 398 insertions(+), 163 deletions(-) delete mode 100644 echo/server/dembrane/service/event.py create mode 100644 echo/server/tests/test_transcribe_assembly.py rename echo/server/tests/{test_transcribe.py => test_transcribe_runpod.py} (100%) diff --git a/echo/server/dembrane/config.py b/echo/server/dembrane/config.py index 4e886eb2..d9f29140 100644 --- a/echo/server/dembrane/config.py +++ b/echo/server/dembrane/config.py @@ -6,6 +6,13 @@ # patterns are inconsistent - ENABLE_LITELLM_WHISPER_TRANSCRIPTION needs to be set # better yet modularize it and have modules manage their own config? +## ENABLE_ASSEMBLYAI_TRANSCRIPTION = os.environ.get( +# "ENABLE_ASSEMBLYAI_TRANSCRIPTION", "false" +# ).lower() in ["true", "1"] +# This is a bad pattern for hygiene because it allows for multiple values to be set if you want it to be true/false + +# This file inits twice for some reason... 
+ import os import sys import logging @@ -177,13 +184,20 @@ DISABLE_CORS = os.environ.get("DISABLE_CORS", "false").lower() in ["true", "1"] logger.debug(f"DISABLE_CORS: {DISABLE_CORS}") -ENABLE_ENGLISH_TRANSCRIPTION_WITH_LITELLM = os.environ.get( - "ENABLE_ENGLISH_TRANSCRIPTION_WITH_LITELLM", "false" +### Transcription + +ENABLE_ASSEMBLYAI_TRANSCRIPTION = os.environ.get( + "ENABLE_ASSEMBLYAI_TRANSCRIPTION", "false" ).lower() in ["true", "1"] -# ENABLE_ENGLISH_TRANSCRIPTION_WITH_LITELLM is optional and defaults to false -logger.debug( - "ENABLE_ENGLISH_TRANSCRIPTION_WITH_LITELLM: %s", ENABLE_ENGLISH_TRANSCRIPTION_WITH_LITELLM -) +logger.debug(f"ENABLE_ASSEMBLYAI_TRANSCRIPTION: {ENABLE_ASSEMBLYAI_TRANSCRIPTION}") + +ASSEMBLYAI_API_KEY = os.environ.get("ASSEMBLYAI_API_KEY") +if ENABLE_ASSEMBLYAI_TRANSCRIPTION: + assert ASSEMBLYAI_API_KEY, "ASSEMBLYAI_API_KEY environment variable is not set" + logger.debug("ASSEMBLYAI_API_KEY: set") + +ASSEMBLYAI_BASE_URL = os.environ.get("ASSEMBLYAI_BASE_URL", "https://api.eu.assemblyai.com") +logger.debug(f"ASSEMBLYAI_BASE_URL: {ASSEMBLYAI_BASE_URL}") ENABLE_RUNPOD_WHISPER_TRANSCRIPTION = os.environ.get( "ENABLE_RUNPOD_WHISPER_TRANSCRIPTION", "false" @@ -195,7 +209,6 @@ assert RUNPOD_WHISPER_API_KEY, "RUNPOD_WHISPER_API_KEY environment variable is not set" logger.debug("RUNPOD_WHISPER_API_KEY: set") - RUNPOD_WHISPER_BASE_URL = os.environ.get("RUNPOD_WHISPER_BASE_URL") if ENABLE_RUNPOD_WHISPER_TRANSCRIPTION: assert RUNPOD_WHISPER_BASE_URL, "RUNPOD_WHISPER_BASE_URL environment variable is not set" @@ -212,6 +225,35 @@ str(os.environ.get("RUNPOD_WHISPER_MAX_REQUEST_THRESHOLD")) ) +ENABLE_LITELLM_WHISPER_TRANSCRIPTION = os.environ.get( + "ENABLE_LITELLM_WHISPER_TRANSCRIPTION", "false" +).lower() in ["true", "1"] +logger.debug(f"ENABLE_LITELLM_WHISPER_TRANSCRIPTION: {ENABLE_LITELLM_WHISPER_TRANSCRIPTION}") + +LITELLM_WHISPER_API_KEY = os.environ.get("LITELLM_WHISPER_API_KEY") +if ENABLE_LITELLM_WHISPER_TRANSCRIPTION: + assert 
LITELLM_WHISPER_API_KEY, "LITELLM_WHISPER_API_KEY environment variable is not set" + logger.debug("LITELLM_WHISPER_API_KEY: set") + +LITELLM_WHISPER_API_VERSION = os.environ.get("LITELLM_WHISPER_API_VERSION", "2024-06-01") +if ENABLE_LITELLM_WHISPER_TRANSCRIPTION: + assert LITELLM_WHISPER_API_VERSION, ( + "LITELLM_WHISPER_API_VERSION environment variable is not set" + ) + logger.debug(f"LITELLM_WHISPER_API_VERSION: {LITELLM_WHISPER_API_VERSION}") + +LITELLM_WHISPER_MODEL = os.environ.get("LITELLM_WHISPER_MODEL") +if ENABLE_LITELLM_WHISPER_TRANSCRIPTION: + assert LITELLM_WHISPER_MODEL, "LITELLM_WHISPER_MODEL environment variable is not set" + logger.debug(f"LITELLM_WHISPER_MODEL: {LITELLM_WHISPER_MODEL}") + +LITELLM_WHISPER_URL = os.environ.get("LITELLM_WHISPER_URL") +if ENABLE_LITELLM_WHISPER_TRANSCRIPTION: + assert LITELLM_WHISPER_URL, "LITELLM_WHISPER_URL environment variable is not set" + logger.debug(f"LITELLM_WHISPER_URL: {LITELLM_WHISPER_URL}") + +### END Transcription + RUNPOD_TOPIC_MODELER_URL = os.environ.get("RUNPOD_TOPIC_MODELER_URL") logger.debug(f"RUNPOD_TOPIC_MODELER_URL: {RUNPOD_TOPIC_MODELER_URL}") @@ -276,33 +318,6 @@ assert LARGE_LITELLM_API_BASE, "LARGE_LITELLM_API_BASE environment variable is not set" logger.debug(f"LARGE_LITELLM_API_BASE: {LARGE_LITELLM_API_BASE}") -ENABLE_LITELLM_WHISPER_TRANSCRIPTION = os.environ.get( - "ENABLE_LITELLM_WHISPER_TRANSCRIPTION", "false" -).lower() in ["true", "1"] -logger.debug(f"ENABLE_LITELLM_WHISPER_TRANSCRIPTION: {ENABLE_LITELLM_WHISPER_TRANSCRIPTION}") - -LITELLM_WHISPER_API_KEY = os.environ.get("LITELLM_WHISPER_API_KEY") -if ENABLE_LITELLM_WHISPER_TRANSCRIPTION: - assert LITELLM_WHISPER_API_KEY, "LITELLM_WHISPER_API_KEY environment variable is not set" - logger.debug("LITELLM_WHISPER_API_KEY: set") - -LITELLM_WHISPER_API_VERSION = os.environ.get("LITELLM_WHISPER_API_VERSION", "2024-06-01") -if ENABLE_LITELLM_WHISPER_TRANSCRIPTION: - assert LITELLM_WHISPER_API_VERSION, ( - "LITELLM_WHISPER_API_VERSION 
environment variable is not set" - ) - logger.debug(f"LITELLM_WHISPER_API_VERSION: {LITELLM_WHISPER_API_VERSION}") - -LITELLM_WHISPER_MODEL = os.environ.get("LITELLM_WHISPER_MODEL") -if ENABLE_LITELLM_WHISPER_TRANSCRIPTION: - assert LITELLM_WHISPER_MODEL, "LITELLM_WHISPER_MODEL environment variable is not set" - logger.debug(f"LITELLM_WHISPER_MODEL: {LITELLM_WHISPER_MODEL}") - -LITELLM_WHISPER_URL = os.environ.get("LITELLM_WHISPER_URL") -if ENABLE_LITELLM_WHISPER_TRANSCRIPTION: - assert LITELLM_WHISPER_URL, "LITELLM_WHISPER_URL environment variable is not set" - logger.debug(f"LITELLM_WHISPER_URL: {LITELLM_WHISPER_URL}") - # *****************LIGHTRAG CONFIGURATIONS***************** # Lightrag LLM model: Makes nodes and answers queries diff --git a/echo/server/dembrane/conversation_health.py b/echo/server/dembrane/conversation_health.py index d899ae87..8c594582 100644 --- a/echo/server/dembrane/conversation_health.py +++ b/echo/server/dembrane/conversation_health.py @@ -10,6 +10,7 @@ from dembrane.s3 import get_signed_url from dembrane.utils import get_utc_timestamp from dembrane.config import ( + ENABLE_RUNPOD_DIARIZATION, RUNPOD_DIARIZATION_API_KEY, RUNPOD_DIARIZATION_TIMEOUT, RUNPOD_DIARIZATION_BASE_URL, @@ -23,7 +24,7 @@ def _fetch_chunk_data(chunk_id: str) -> tuple[str, str] | None: """ Retrieves the audio file URI and project language for a given chunk ID from Directus. - + Returns: A tuple containing (audio_file_uri, project_language) if successful, or None if retrieval fails. 
""" @@ -33,13 +34,15 @@ def _fetch_chunk_data(chunk_id: str) -> tuple[str, str] | None: { "query": { "filter": {"id": {"_eq": chunk_id}}, - "fields": ["path", "conversation_id.project_id.language"] + "fields": ["path", "conversation_id.project_id.language"], } - } + }, )[0] audio_file_uri = directus_item["path"] project_language = directus_item["conversation_id"]["project_id"]["language"] - logger.debug(f"Starting diarization for chunk_id: {chunk_id}, path: {audio_file_uri}, project_language: {project_language}") + logger.debug( + f"Starting diarization for chunk_id: {chunk_id}, path: {audio_file_uri}, project_language: {project_language}" + ) return audio_file_uri, project_language except Exception as e: logger.error(f"Failed to fetch audio_file_uri for chunk_id {chunk_id}: {e}") @@ -49,10 +52,10 @@ def _fetch_chunk_data(chunk_id: str) -> tuple[str, str] | None: def _generate_audio_url(audio_file_uri: str) -> str | None: """ Generates a signed URL for the specified audio file. - + Args: audio_file_uri: The URI of the audio file to sign. - + Returns: The signed URL as a string if successful, or None if signing fails. """ @@ -68,7 +71,7 @@ def _generate_audio_url(audio_file_uri: str) -> str | None: def _should_skip_diarization(project_language: str) -> bool: """ Determines whether diarization should be skipped for a given project language. - + Returns True if diarization is disabled for non-English languages based on configuration; otherwise, returns False. """ if DISABLE_MULTILINGUAL_DIARIZATION and project_language != "en": @@ -80,11 +83,11 @@ def _should_skip_diarization(project_language: str) -> bool: def _submit_diarization_job(audio_url: str, project_language: str) -> tuple[str, str] | None: """ Submits an audio diarization job to RunPod using the provided audio URL and project language. - + Args: audio_url: The signed URL of the audio file to be processed. project_language: The language code associated with the project. 
- + Returns: A tuple containing the job ID and the job status link if submission is successful, or None if the request fails. """ @@ -98,7 +101,7 @@ def _submit_diarization_job(audio_url: str, project_language: str) -> tuple[str, "Authorization": f"Bearer {api_key}", } data = {"input": {"audio": audio_url, "language": project_language}} - + try: logger.debug(f"Sending POST to {base_url}/run with data: {data}") response = requests.post(f"{base_url}/run", headers=headers, json=data, timeout=timeout) @@ -115,11 +118,11 @@ def _submit_diarization_job(audio_url: str, project_language: str) -> tuple[str, def _poll_job_status(job_status_link: str, headers: dict) -> dict | None: """ Retrieves the current status of a diarization job from the provided status link. - + Args: job_status_link: The URL to poll for job status. headers: HTTP headers to include in the request. - + Returns: The JSON response containing job status information, or None if the request fails. """ @@ -136,7 +139,7 @@ def _poll_job_status(job_status_link: str, headers: dict) -> dict | None: def _update_chunk_with_results(chunk_id: str, dirz_response_data: dict) -> None: """ Updates a conversation chunk in Directus with diarization analysis results. - + Args: chunk_id: The ID of the conversation chunk to update. dirz_response_data: Dictionary containing diarization metrics and results to store. @@ -145,7 +148,7 @@ def _update_chunk_with_results(chunk_id: str, dirz_response_data: dict) -> None: cross_talk_instances = dirz_response_data.get("cross_talk_instances") silence_ratio = dirz_response_data.get("silence_ratio") joined_diarization = dirz_response_data.get("joined_diarization") - + directus.update_item( "conversation_chunk", chunk_id, @@ -162,7 +165,7 @@ def _update_chunk_with_results(chunk_id: str, dirz_response_data: dict) -> None: def _cancel_job_on_timeout(job_id: str) -> None: """ Cancels a diarization job on RunPod if it has exceeded the allowed processing time. 
- + Logs a warning before attempting cancellation and logs an error if the cancellation fails. """ base_url = RUNPOD_DIARIZATION_BASE_URL @@ -171,7 +174,7 @@ def _cancel_job_on_timeout(job_id: str) -> None: "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", } - + try: cancel_endpoint = f"{base_url}/cancel/{job_id}" logger.warning(f"Timeout reached. Cancelling diarization job {job_id} at {cancel_endpoint}") @@ -187,10 +190,14 @@ def get_runpod_diarization( ) -> None: """ Orchestrates the diarization process for a given chunk by submitting an audio diarization job to RunPod, polling for completion within a timeout, and updating Directus with the results or canceling the job if it times out. - + Args: chunk_id: The identifier of the audio chunk to process. """ + if not ENABLE_RUNPOD_DIARIZATION: + logger.debug("Skipping diarization because ENABLE_RUNPOD_DIARIZATION is disabled") + return None + # Fetch chunk data chunk_data = _fetch_chunk_data(chunk_id) if not chunk_data: @@ -219,24 +226,28 @@ def get_runpod_diarization( "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", } - + start_time = time.time() while time.time() - start_time < timeout: response_data = _poll_job_status(job_status_link, headers) if response_data: status = response_data.get("status") logger.debug(f"Job {job_id} status: {status}") - + if status == "COMPLETED": dirz_response_data = response_data.get("output") if dirz_response_data: - logger.info(f"Diarization job {job_id} completed. Updating chunk {chunk_id} with results.") + logger.info( + f"Diarization job {job_id} completed. Updating chunk {chunk_id} with results." + ) _update_chunk_with_results(chunk_id, dirz_response_data) return else: - logger.warning(f"Diarization job {job_id} completed but no output data received.") + logger.warning( + f"Diarization job {job_id} completed but no output data received." 
+ ) return - + time.sleep(3) # Timeout: cancel the job @@ -254,6 +265,10 @@ def get_health_status( """ Get the health status of conversations. """ + if not ENABLE_RUNPOD_DIARIZATION: + logger.debug("Skipping diarization because ENABLE_RUNPOD_DIARIZATION is disabled") + return {} + if not project_ids and not conversation_ids: raise ValueError("Either project_ids or conversation_ids must be provided") @@ -325,7 +340,7 @@ def _get_timebound_conversation_chunks( }, }, ) - try: + try: response = response[:max_chunks_for_conversation] aggregated_response.extend(_flatten_response(response)) except Exception as e: diff --git a/echo/server/dembrane/s3.py b/echo/server/dembrane/s3.py index 21b20b3c..d34e0f00 100644 --- a/echo/server/dembrane/s3.py +++ b/echo/server/dembrane/s3.py @@ -141,6 +141,9 @@ def save_to_s3_from_file_like( def get_signed_url(file_name: str, expires_in_seconds: int = 3600) -> str: + """ + WARNING: this will also "get fake signed urls" for files that don't exist + """ return s3_client.generate_presigned_url( "get_object", Params={"Bucket": STORAGE_S3_BUCKET, "Key": get_sanitized_s3_key(file_name)}, diff --git a/echo/server/dembrane/service/__init__.py b/echo/server/dembrane/service/__init__.py index d6443deb..626bca6d 100644 --- a/echo/server/dembrane/service/__init__.py +++ b/echo/server/dembrane/service/__init__.py @@ -15,18 +15,35 @@ project = project_service.get_by_id_or_raise(project_id) """ -# Import service classes -from .file import get_file_service -from .event import EventService -from .project import ProjectService -from .conversation import ConversationService +from .file import FileServiceException, get_file_service +from .project import ProjectService, ProjectServiceException, ProjectNotFoundException +from .conversation import ( + ConversationService, + ConversationServiceException, + ConversationNotFoundException, + ConversationChunkNotFoundException, + ConversationNotOpenForParticipationException, +) -# Create service instances without 
circular dependencies file_service = get_file_service() -event_service = EventService() project_service = ProjectService() conversation_service = ConversationService( file_service=file_service, - event_service=event_service, project_service=project_service, ) + +exceptions = { + "file": { + "FileServiceException": FileServiceException, + }, + "conversation": { + "ConversationChunkNotFoundException": ConversationChunkNotFoundException, + "ConversationNotFoundException": ConversationNotFoundException, + "ConversationNotOpenForParticipationException": ConversationNotOpenForParticipationException, + "ConversationServiceException": ConversationServiceException, + }, + "project": { + "ProjectNotFoundException": ProjectNotFoundException, + "ProjectServiceException": ProjectServiceException, + }, +} diff --git a/echo/server/dembrane/service/conversation.py b/echo/server/dembrane/service/conversation.py index cce0f0c1..1c6a7321 100644 --- a/echo/server/dembrane/service/conversation.py +++ b/echo/server/dembrane/service/conversation.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: from dembrane.service.file import FileService - from dembrane.service.event import EventService from dembrane.service.project import ProjectService # allows for None to be a sentinel value @@ -36,11 +35,9 @@ class ConversationService: def __init__( self, file_service: "FileService", - event_service: "EventService", project_service: "ProjectService", ): self.file_service = file_service - self.event_service = event_service self.project_service = project_service def get_by_id_or_raise( @@ -280,9 +277,10 @@ def create_chunk( def update_chunk( self, chunk_id: str, - raw_transcript: Any = _UNSET, - transcript: Any = _UNSET, path: Any = _UNSET, + diarization: Any = _UNSET, + transcript: Any = _UNSET, + raw_transcript: Any = _UNSET, runpod_job_status_link: Any = _UNSET, error: Any = _UNSET, hallucination_reason: Any = _UNSET, @@ -296,6 +294,9 @@ def update_chunk( if raw_transcript is not _UNSET: 
update["raw_transcript"] = raw_transcript + if diarization is not _UNSET: + update["diarization"] = diarization + if transcript is not _UNSET: update["transcript"] = transcript diff --git a/echo/server/dembrane/service/event.py b/echo/server/dembrane/service/event.py deleted file mode 100644 index c97517b3..00000000 --- a/echo/server/dembrane/service/event.py +++ /dev/null @@ -1,34 +0,0 @@ -from uuid import uuid4 -from datetime import datetime, timezone - -from pydantic import Field, BaseModel - - -class BaseEvent(BaseModel): - """Base class for all events""" - - event_id: str = Field(default_factory=lambda: str(uuid4())) - timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - - # Set by the event service - event_name: str - - class Config: - json_encoders = {datetime: lambda v: v.isoformat()} - - -class ChunkCreatedEvent(BaseEvent): - event_name: str = "system.chunk.created" - - chunk_id: str - conversation_id: str - - -class EventService: - """Service for publishing events""" - - def publish( - self, - event: BaseEvent, - ) -> None: - pass diff --git a/echo/server/dembrane/transcribe.py b/echo/server/dembrane/transcribe.py index e91e8fda..1cd75b16 100644 --- a/echo/server/dembrane/transcribe.py +++ b/echo/server/dembrane/transcribe.py @@ -1,9 +1,17 @@ +""" +File is messy. Need to split implementations of different transcription providers into different classes perhaps. +Add interface for a generic transcription provider. (Which can be sync or async.) +But it is probably not needed. 
+Can provide selfhost options through "litellm" and api use through "assembly" +""" + # transcribe.py import io import os +import time import logging import mimetypes -from typing import List, Optional +from typing import Any, List, Literal, Optional import requests from litellm import transcription @@ -11,18 +19,22 @@ from dembrane.s3 import get_signed_url, get_stream_from_s3 from dembrane.config import ( API_BASE_URL, + ASSEMBLYAI_API_KEY, + ASSEMBLYAI_BASE_URL, LITELLM_WHISPER_URL, LITELLM_WHISPER_MODEL, RUNPOD_WHISPER_API_KEY, LITELLM_WHISPER_API_KEY, RUNPOD_WHISPER_BASE_URL, LITELLM_WHISPER_API_VERSION, + ENABLE_ASSEMBLYAI_TRANSCRIPTION, RUNPOD_WHISPER_PRIORITY_BASE_URL, ENABLE_RUNPOD_WHISPER_TRANSCRIPTION, ENABLE_LITELLM_WHISPER_TRANSCRIPTION, RUNPOD_WHISPER_MAX_REQUEST_THRESHOLD, ) from dembrane.prompts import render_prompt +from dembrane.service import conversation_service from dembrane.directus import directus logger = logging.getLogger("transcribe") @@ -114,6 +126,60 @@ def transcribe_audio_litellm( raise TranscriptionError(f"LiteLLM transcription failed: {e}") from e +def transcribe_audio_assemblyai( + audio_file_uri: str, + language: Optional[str], # pyright: ignore[reportUnusedParameter] + hotwords: Optional[List[str]], +) -> tuple[str, dict[str, Any]]: + """Transcribe audio through AssemblyAI""" + logger = logging.getLogger("transcribe.transcribe_audio_assemblyai") + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {ASSEMBLYAI_API_KEY}", + } + + data: dict[str, Any] = { + "audio_url": audio_file_uri, + "speech_model": "universal", + "language_detection_options": { + "expected_languages": ["nl", "en", "fr", "es", "de", "it", "pt"], + }, + } + + if language: + if language == "auto": + data["language_detection"] = True + else: + data["language_code"] = language + + if hotwords: + data["keyterms_prompt"] = hotwords + + try: + response = requests.post(f"{ASSEMBLYAI_BASE_URL}/v2/transcript", headers=headers, json=data) + 
response.raise_for_status() + + transcript_id = response.json()["id"] + polling_endpoint = f"{ASSEMBLYAI_BASE_URL}/v2/transcript/{transcript_id}" + + # TODO: using webhooks will be ideal, but this is easy to impl and test for ;) + # we will be blocking some of our cheap "workers" here with time.sleep + while True: + transcript = requests.get(polling_endpoint, headers=headers).json() + if transcript["status"] == "completed": + # return both to add the diarization response later... + return transcript["text"], transcript + elif transcript["status"] == "error": + raise RuntimeError(f"Transcription failed: {transcript['error']}") + else: + time.sleep(3) + + except Exception as e: + logger.error(f"AssemblyAI transcription failed: {e}") + raise TranscriptionError(f"AssemblyAI transcription failed: {e}") from e + + # Helper functions extracted to simplify `transcribe_conversation_chunk` # NOTE: These are internal helpers ‑ they should **not** be considered part of the public API. @@ -156,6 +222,14 @@ def _fetch_conversation(conversation_id: str) -> dict: return conversation_rows[0] +def _save_transcript( + conversation_chunk_id: str, transcript: str, diarization: Optional[dict] = None +) -> None: + conversation_service.update_chunk( + conversation_chunk_id, transcript=transcript, diarization=diarization + ) + + def _build_whisper_prompt(conversation: dict, language: str) -> str: """Compose the whisper prompt from defaults and project-specific overrides.""" default_prompt = render_prompt("default_whisper_prompt", language, {}) @@ -171,24 +245,26 @@ def _build_whisper_prompt(conversation: dict, language: str) -> str: return " ".join(prompt_parts) -def _should_use_runpod(language: str) -> bool: - """Decide whether RunPod should be used for the given language.""" - logger.debug(f"the language str is unused: {language}") - if not ENABLE_RUNPOD_WHISPER_TRANSCRIPTION: - return False - # Removed English + override logic - now use RunPod for all languages - return True +def 
_build_hotwords(conversation: dict) -> Optional[List[str]]: + """Build the hotwords from the conversation""" + hotwords_str = conversation["project_id"].get("default_conversation_transcript_prompt") + if hotwords_str: + return [str(word.strip()) for word in hotwords_str.split(",")] + return None -def _should_use_litellm() -> bool: - """Decide whether LiteLLM should be used for the given language.""" - if not ENABLE_LITELLM_WHISPER_TRANSCRIPTION: - return False - return True +def _get_transcript_provider() -> Literal["Runpod", "LiteLLM", "AssemblyAI"]: + if ENABLE_ASSEMBLYAI_TRANSCRIPTION: + return "AssemblyAI" + elif ENABLE_RUNPOD_WHISPER_TRANSCRIPTION: + return "Runpod" + elif ENABLE_LITELLM_WHISPER_TRANSCRIPTION: + return "LiteLLM" + else: + raise TranscriptionError("No valid transcription configuration found.") def _get_status_runpod(runpod_job_status_link: str) -> tuple[str, dict]: - """Get the status of a RunPod job.""" headers = { "Content-Type": "application/json", "Authorization": f"Bearer {RUNPOD_WHISPER_API_KEY}", @@ -264,10 +340,9 @@ def _process_runpod_transcription( def transcribe_conversation_chunk(conversation_chunk_id: str) -> str: - """Process conversation chunk for transcription - - Note: RunPod is now used for all transcription when enabled. - Falls back to LiteLLM only if RunPod is disabled. 
+ """ + Process conversation chunk for transcription + matches on _get_transcript_provider() Returns: str: The conversation chunk ID if successful @@ -281,48 +356,40 @@ def transcribe_conversation_chunk(conversation_chunk_id: str) -> str: chunk = _fetch_chunk(conversation_chunk_id) conversation = _fetch_conversation(chunk["conversation_id"]) language = conversation["project_id"]["language"] or "en" - logger.debug(f"using language: {language}") - - whisper_prompt = _build_whisper_prompt(conversation, language) - - logger.debug(f"whisper_prompt: {whisper_prompt}") - - if _should_use_runpod(language): - logger.info("Using RunPod for transcription") - - hotwords_str = conversation["project_id"].get( - "default_conversation_transcript_prompt", None - ) - - hotwords = hotwords_str.split(",") if hotwords_str else None - return _process_runpod_transcription(chunk, conversation_chunk_id, language, hotwords) - - elif _should_use_litellm(): - logger.info("Using LITELLM for transcription") - - transcript = transcribe_audio_litellm( - chunk["path"], language=language, whisper_prompt=whisper_prompt - ) - logger.debug(f"transcript: {transcript}") - - directus.update_item( - "conversation_chunk", - conversation_chunk_id, - { - "transcript": transcript, - }, - ) - - logger.info(f"Processed chunk for transcription: {conversation_chunk_id}") - return conversation_chunk_id - - else: - raise TranscriptionError( - "No valid transcription configuration found." - "If `ENABLE_ENGLISH_TRANSCRIPTION_WITH_LITELLM` is enabled, " - "then `ENABLE_LITELLM_WHISPER_TRANSCRIPTION` must be enabled." 
- ) + transcript_provider = _get_transcript_provider() + + match transcript_provider: + case "AssemblyAI": + logger.info("Using AssemblyAI for transcription") + hotwords = _build_hotwords(conversation) + signed_url = get_signed_url(chunk["path"], expires_in_seconds=3 * 24 * 60 * 60) + transcript, assemblyai_response = transcribe_audio_assemblyai( + signed_url, language=language, hotwords=hotwords + ) + _save_transcript( + conversation_chunk_id, + transcript, + diarization={ + "schema": "ASSEMBLYAI", + "data": assemblyai_response.get("words", {}), + }, + ) + return conversation_chunk_id + case "Runpod": + logger.info("Using RunPod for transcription") + hotwords = _build_hotwords(conversation) + return _process_runpod_transcription( + chunk, conversation_chunk_id, language, hotwords + ) + case "LiteLLM": + logger.info("Using LITELLM for transcription") + whisper_prompt = _build_whisper_prompt(conversation, language) + transcript = transcribe_audio_litellm( + chunk["path"], language=language, whisper_prompt=whisper_prompt + ) + _save_transcript(conversation_chunk_id, transcript, diarization=None) + return conversation_chunk_id except Exception as e: logger.error(f"Failed to process conversation chunk {conversation_chunk_id}: {e}") diff --git a/echo/server/tests/test_transcribe_assembly.py b/echo/server/tests/test_transcribe_assembly.py new file mode 100644 index 00000000..9eab1bec --- /dev/null +++ b/echo/server/tests/test_transcribe_assembly.py @@ -0,0 +1,151 @@ +import os +import logging + +import pytest + +from dembrane.s3 import delete_from_s3, save_to_s3_from_url +from dembrane.utils import get_utc_timestamp +from dembrane.directus import directus +from dembrane.transcribe import transcribe_audio_assemblyai, transcribe_conversation_chunk + +logger = logging.getLogger("test_transcribe_assembly") + + +def _require_assemblyai(): + """Ensure AssemblyAI is enabled and credentials are present or skip.""" + if not os.environ.get("ASSEMBLYAI_API_KEY"): + 
pytest.skip("ASSEMBLYAI_API_KEY not set; skipping AssemblyAI tests") + # Force provider selection to AssemblyAI in config by env flags + os.environ["ENABLE_ASSEMBLYAI_TRANSCRIPTION"] = "true" + os.environ["ENABLE_RUNPOD_WHISPER_TRANSCRIPTION"] = "false" + os.environ["ENABLE_LITELLM_WHISPER_TRANSCRIPTION"] = "false" + + +@pytest.fixture +def fixture_chunk_en(): + _require_assemblyai() + logger.info("setup") + + p = directus.create_item( + "project", + { + "name": "test", + "language": "en", + "is_conversation_allowed": True, + }, + )["data"] + + c = directus.create_item( + "conversation", + {"project_id": p["id"], "participant_name": "test_assembly_en", "language": "en"}, + )["data"] + + path = save_to_s3_from_url( + "https://github.com/runpod-workers/sample-inputs/raw/refs/heads/main/audio/Arthur.mp3", + public=True, + ) + + cc = directus.create_item( + "conversation_chunk", + { + "conversation_id": c["id"], + "path": path, + "timestamp": str(get_utc_timestamp()), + }, + )["data"] + + yield { + "project_id": p["id"], + "conversation_id": c["id"], + "chunk_id": cc["id"], + "path": path, + } + + logger.info("teardown") + directus.delete_item("conversation_chunk", cc["id"]) + directus.delete_item("conversation", c["id"]) + directus.delete_item("project", p["id"]) + delete_from_s3(path) + + +@pytest.fixture +def fixture_chunk_nl(): + _require_assemblyai() + logger.info("setup") + + p = directus.create_item( + "project", + { + "name": "test", + "language": "nl", + "is_conversation_allowed": True, + }, + )["data"] + + c = directus.create_item( + "conversation", + {"project_id": p["id"], "participant_name": "test_assembly_nl", "language": "nl"}, + )["data"] + + path = save_to_s3_from_url( + "https://github.com/runpod-workers/sample-inputs/raw/refs/heads/main/audio/Arthur.mp3", + public=True, + ) + + cc = directus.create_item( + "conversation_chunk", + { + "conversation_id": c["id"], + "path": path, + "timestamp": str(get_utc_timestamp()), + }, + )["data"] + + yield { + 
"project_id": p["id"], + "conversation_id": c["id"], + "chunk_id": cc["id"], + "path": path, + } + + logger.info("teardown") + directus.delete_item("conversation_chunk", cc["id"]) + directus.delete_item("conversation", c["id"]) + directus.delete_item("project", p["id"]) + delete_from_s3(path) + + +class TestTranscribeAssemblyAI: + def test_transcribe_conversation_chunk_en(self, fixture_chunk_en): + chunk_id = fixture_chunk_en["chunk_id"] + result_id = transcribe_conversation_chunk(chunk_id) + assert result_id == chunk_id + + # fetch chunk and validate transcript saved (API is synchronous) + cc = dict(directus.get_item("conversation_chunk", result_id)) + assert cc.get("transcript") is not None + assert isinstance(cc.get("transcript"), str) + assert len(cc.get("transcript")) > 0 + + def test_transcribe_conversation_chunk_nl(self, fixture_chunk_nl): + chunk_id = fixture_chunk_nl["chunk_id"] + result_id = transcribe_conversation_chunk(chunk_id) + assert result_id == chunk_id + + cc = dict(directus.get_item("conversation_chunk", result_id)) + assert cc.get("transcript") is not None + assert isinstance(cc.get("transcript"), str) + assert len(cc.get("transcript")) > 0 + + +def test_transcribe_audio_assemblyai(): + transcript, response = transcribe_audio_assemblyai( + audio_file_uri="https://github.com/runpod-workers/sample-inputs/raw/refs/heads/main/audio/Arthur.mp3", + language="en", + hotwords=["Arthur"], + ) + + assert transcript is not None + assert response is not None + assert response.get("words") is not None + assert len(transcript) > 0 diff --git a/echo/server/tests/test_transcribe.py b/echo/server/tests/test_transcribe_runpod.py similarity index 100% rename from echo/server/tests/test_transcribe.py rename to echo/server/tests/test_transcribe_runpod.py