diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8165de0f..2b0cf254 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -189,10 +189,10 @@ jobs: pip install -r backend/requirements.txt pip install --no-deps chatterbox-tts - - name: Install PyTorch with CUDA 12.1 + - name: Install PyTorch with CUDA 12.6 run: | - pip install torch --index-url https://download.pytorch.org/whl/cu121 --force-reinstall --no-deps - pip install torchaudio --index-url https://download.pytorch.org/whl/cu121 + pip install torch --index-url https://download.pytorch.org/whl/cu126 --force-reinstall --no-deps + pip install torchaudio --index-url https://download.pytorch.org/whl/cu126 --force-reinstall --no-deps - name: Verify CUDA support in torch run: | diff --git a/.gitignore b/.gitignore index 5f87d802..515c059b 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ logs/ app/openapi.json tauri/src-tauri/binaries/* tauri/src-tauri/gen/Assets.car +tauri/src-tauri/gen/voicebox.icns # Temporary tmp/ diff --git a/Dockerfile b/Dockerfile index 4705f98c..20da9e1a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,6 +31,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential \ && rm -rf /var/lib/apt/lists/* +RUN pip install --no-cache-dir --upgrade pip + COPY backend/requirements.txt . RUN pip install --no-cache-dir --prefix=/install -r requirements.txt RUN pip install --no-cache-dir --prefix=/install \ diff --git a/app/src/components/History/HistoryTable.tsx b/app/src/components/History/HistoryTable.tsx index 65f201d7..e88c7701 100644 --- a/app/src/components/History/HistoryTable.tsx +++ b/app/src/components/History/HistoryTable.tsx @@ -153,7 +153,9 @@ export function HistoryTable() { } }, [historyData, page]); - // Reset to page 0 when deletions or imports occur + // Reset to page 0 when deletions, imports, or generation completions occur + const pendingCount = useGenerationStore((state) => state.pendingGenerationIds.size); + const prevPendingCountRef = useRef(pendingCount); useEffect(() => { if (deleteGeneration.isSuccess || importGeneration.isSuccess) { setPage(0); @@ -161,6 +163,19 @@ export function HistoryTable() { } }, [deleteGeneration.isSuccess, importGeneration.isSuccess]); + useEffect(() => { + // A generation finished (pending count decreased) — scroll back to show it + if ( + prevPendingCountRef.current > 0 && + pendingCount < prevPendingCountRef.current && + page !== 0 + ) { + setPage(0); + setAllHistory([]); + } + prevPendingCountRef.current = pendingCount; + }, [pendingCount, page]); + // Intersection Observer for infinite scroll useEffect(() => { const loadMoreEl = loadMoreRef.current; diff --git a/app/src/lib/api/client.ts b/app/src/lib/api/client.ts index c6691ab5..98f98182 100644 --- a/app/src/lib/api/client.ts +++ b/app/src/lib/api/client.ts @@ -32,8 +32,24 @@ import type { TranscriptionResponse, VoiceProfileCreate, VoiceProfileResponse, + WhisperModelSize, } from './types'; +function formatErrorDetail(detail: unknown, fallback: string): string { + if (typeof detail === 'string') return detail; + if (Array.isArray(detail)) { + return detail + .map((e: Record) => e.msg || e.message || JSON.stringify(e)) + .join('; '); + } + if (detail && typeof detail === 'object') { + const obj = detail as Record; + if (typeof obj.message === 'string') return obj.message; + return JSON.stringify(detail); + } + return fallback; +} + class ApiClient { private getBaseUrl(): string { const serverUrl = useServerStore.getState().serverUrl; @@ -54,7 +70,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.json(); @@ -113,7 +129,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.json(); @@ -147,7 +163,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.blob(); @@ -167,7 +183,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.json(); @@ -187,7 +203,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.json(); @@ -257,7 +273,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.blob(); @@ -271,7 +287,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.blob(); @@ -297,7 +313,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.json(); @@ -318,12 +334,19 @@ class ApiClient { } // Transcription - async transcribeAudio(file: File, language?: LanguageCode): Promise { + async transcribeAudio( + file: File, + language?: LanguageCode, + model?: WhisperModelSize, + ): Promise { const formData = new FormData(); formData.append('file', file); if (language) { formData.append('language', language); } + if (model) { + formData.append('model', model); + } const url = `${this.getBaseUrl()}/transcribe`; const response = await fetch(url, { @@ -335,7 +358,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.json(); @@ -608,7 +631,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.blob(); @@ -705,7 +728,7 @@ class ApiClient { const error = await response.json().catch(() => ({ detail: response.statusText, })); - throw new Error(error.detail || `HTTP error! status: ${response.status}`); + throw new Error(formatErrorDetail(error.detail, `HTTP error! status: ${response.status}`)); } return response.blob(); diff --git a/app/src/lib/api/types.ts b/app/src/lib/api/types.ts index 49e90918..daae2a95 100644 --- a/app/src/lib/api/types.ts +++ b/app/src/lib/api/types.ts @@ -99,8 +99,11 @@ export interface HistoryListResponse { total: number; } +export type WhisperModelSize = 'base' | 'small' | 'medium' | 'large' | 'turbo'; + export interface TranscriptionRequest { language?: LanguageCode; + model?: WhisperModelSize; } export interface TranscriptionResponse { diff --git a/app/src/lib/hooks/useGenerationProgress.ts b/app/src/lib/hooks/useGenerationProgress.ts index 17d9e0cd..4c6e9143 100644 --- a/app/src/lib/hooks/useGenerationProgress.ts +++ b/app/src/lib/hooks/useGenerationProgress.ts @@ -75,8 +75,8 @@ export function useGenerationProgress() { currentSources.delete(id); removePendingGeneration(id); - // Refresh history to pick up the completed generation - queryClient.invalidateQueries({ queryKey: ['history'] }); + // Refetch history to pick up the completed generation + queryClient.refetchQueries({ queryKey: ['history'] }); // If this generation was queued for a story, add it now const storyId = removePendingStoryAdd(id); @@ -120,7 +120,7 @@ export function useGenerationProgress() { removePendingGeneration(id); removePendingStoryAdd(id); - queryClient.invalidateQueries({ queryKey: ['history'] }); + queryClient.refetchQueries({ queryKey: ['history'] }); toast({ title: data.status === 'not_found' ? 'Generation not found' : 'Generation failed', @@ -134,11 +134,12 @@ export function useGenerationProgress() { }; source.onerror = () => { - // EventSource auto-reconnects, but if we get repeated errors - // just clean up + // SSE connection dropped — clean up and refresh history so any + // completed/failed generation still appears in the list source.close(); currentSources.delete(id); removePendingGeneration(id); + queryClient.refetchQueries({ queryKey: ['history'] }); }; currentSources.set(id, source); diff --git a/app/src/lib/hooks/useTranscription.ts b/app/src/lib/hooks/useTranscription.ts index 0b80722f..641b02df 100644 --- a/app/src/lib/hooks/useTranscription.ts +++ b/app/src/lib/hooks/useTranscription.ts @@ -1,10 +1,18 @@ import { useMutation } from '@tanstack/react-query'; import { apiClient } from '@/lib/api/client'; +import type { WhisperModelSize } from '@/lib/api/types'; import type { LanguageCode } from '@/lib/constants/languages'; export function useTranscription() { return useMutation({ - mutationFn: ({ file, language }: { file: File; language?: LanguageCode }) => - apiClient.transcribeAudio(file, language), + mutationFn: ({ + file, + language, + model, + }: { + file: File; + language?: LanguageCode; + model?: WhisperModelSize; + }) => apiClient.transcribeAudio(file, language, model), }); } diff --git a/backend/backends/__init__.py b/backend/backends/__init__.py index cc35eabe..6f20f3de 100644 --- a/backend/backends/__init__.py +++ b/backend/backends/__init__.py @@ -134,6 +134,7 @@ async def transcribe( self, audio_path: str, language: Optional[str] = None, + model_size: Optional[str] = None, ) -> str: """ Transcribe audio to text. diff --git a/backend/backends/mlx_backend.py b/backend/backends/mlx_backend.py index e4a1ea97..92405dbf 100644 --- a/backend/backends/mlx_backend.py +++ b/backend/backends/mlx_backend.py @@ -345,18 +345,20 @@ async def transcribe( self, audio_path: str, language: Optional[str] = None, + model_size: Optional[str] = None, ) -> str: """ Transcribe audio to text. Args: audio_path: Path to audio file - language: Optional language hint (en or zh) + language: Optional language hint + model_size: Optional model size override Returns: Transcribed text """ - await self.load_model_async(None) + await self.load_model_async(model_size) def _transcribe_sync(): """Run synchronous transcription in thread pool.""" diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py index 8f4a7a58..8ed4ab9c 100644 --- a/backend/backends/pytorch_backend.py +++ b/backend/backends/pytorch_backend.py @@ -306,18 +306,20 @@ async def transcribe( self, audio_path: str, language: Optional[str] = None, + model_size: Optional[str] = None, ) -> str: """ Transcribe audio to text. Args: audio_path: Path to audio file - language: Optional language hint (en or zh) + language: Optional language hint + model_size: Optional model size override Returns: Transcribed text """ - await self.load_model_async(None) + await self.load_model_async(model_size) def _transcribe_sync(): """Run synchronous transcription in thread pool.""" diff --git a/backend/models.py b/backend/models.py index ef8e196d..3308b3bc 100644 --- a/backend/models.py +++ b/backend/models.py @@ -149,7 +149,8 @@ class HistoryListResponse(BaseModel): class TranscriptionRequest(BaseModel): """Request model for audio transcription.""" - language: Optional[str] = Field(None, pattern="^(en|zh)$") + language: Optional[str] = Field(None, pattern="^(en|zh|ja|ko|de|fr|ru|pt|es|it)$") + model: Optional[str] = Field(None, pattern="^(base|small|medium|large|turbo)$") class TranscriptionResponse(BaseModel): diff --git a/backend/routes/generations.py b/backend/routes/generations.py index 160b0aa1..8541b659 100644 --- a/backend/routes/generations.py +++ b/backend/routes/generations.py @@ -1,12 +1,15 @@ """TTS generation endpoints.""" import asyncio +import logging import uuid from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session +logger = logging.getLogger(__name__) + from .. import models from ..services import history, profiles, tts from ..database import Generation as DBGeneration, VoiceProfile as DBVoiceProfile, get_db @@ -181,25 +184,28 @@ async def get_generation_status(generation_id: str, db: Session = Depends(get_db import json async def event_stream(): - while True: - db.expire_all() - gen = db.query(DBGeneration).filter_by(id=generation_id).first() - if not gen: - yield f"data: {json.dumps({'status': 'not_found', 'id': generation_id})}\n\n" - return - - payload = { - "id": gen.id, - "status": gen.status or "completed", - "duration": gen.duration, - "error": gen.error, - } - yield f"data: {json.dumps(payload)}\n\n" - - if (gen.status or "completed") in ("completed", "failed"): - return - - await asyncio.sleep(1) + try: + while True: + db.expire_all() + gen = db.query(DBGeneration).filter_by(id=generation_id).first() + if not gen: + yield f"data: {json.dumps({'status': 'not_found', 'id': generation_id})}\n\n" + return + + payload = { + "id": gen.id, + "status": gen.status or "completed", + "duration": gen.duration, + "error": gen.error, + } + yield f"data: {json.dumps(payload)}\n\n" + + if (gen.status or "completed") in ("completed", "failed"): + return + + await asyncio.sleep(1) + except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError): + logger.debug("SSE client disconnected for generation %s", generation_id) return StreamingResponse( event_stream(), @@ -265,9 +271,12 @@ async def stream_speech( wav_bytes = tts.audio_to_wav_bytes(audio, sample_rate) async def _wav_stream(): - chunk_size = 64 * 1024 - for i in range(0, len(wav_bytes), chunk_size): - yield wav_bytes[i : i + chunk_size] + try: + chunk_size = 64 * 1024 + for i in range(0, len(wav_bytes), chunk_size): + yield wav_bytes[i : i + chunk_size] + except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError): + logger.debug("Client disconnected during audio stream") return StreamingResponse( _wav_stream(), diff --git a/backend/routes/profiles.py b/backend/routes/profiles.py index 6b7b4509..5b2257e0 100644 --- a/backend/routes/profiles.py +++ b/backend/routes/profiles.py @@ -102,6 +102,10 @@ async def delete_profile( return {"message": "Profile deleted successfully"} +SAMPLE_MAX_FILE_SIZE = 50 * 1024 * 1024 # 50 MB +SAMPLE_UPLOAD_CHUNK_SIZE = 1024 * 1024 # 1 MB + + @router.post("/profiles/{profile_id}/samples", response_model=models.ProfileSampleResponse) async def add_profile_sample( profile_id: str, @@ -115,8 +119,16 @@ async def add_profile_sample( file_suffix = _uploaded_ext if _uploaded_ext in _allowed_audio_exts else ".wav" with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp: - content = await file.read() - tmp.write(content) + total_size = 0 + while chunk := await file.read(SAMPLE_UPLOAD_CHUNK_SIZE): + total_size += len(chunk) + if total_size > SAMPLE_MAX_FILE_SIZE: + Path(tmp.name).unlink(missing_ok=True) + raise HTTPException( + status_code=413, + detail=f"File too large (max {SAMPLE_MAX_FILE_SIZE // (1024 * 1024)} MB)", + ) + tmp.write(chunk) tmp_path = tmp.name try: diff --git a/backend/routes/transcription.py b/backend/routes/transcription.py index 90cb1c95..dc949132 100644 --- a/backend/routes/transcription.py +++ b/backend/routes/transcription.py @@ -20,6 +20,7 @@ async def transcribe_audio( file: UploadFile = File(...), language: str | None = Form(None), + model: str | None = Form(None), ): """Transcribe audio file to text.""" with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: @@ -29,14 +30,23 @@ async def transcribe_audio( try: from ..utils.audio import load_audio + from ..backends import WHISPER_HF_REPOS audio, sr = await asyncio.to_thread(load_audio, tmp_path) duration = len(audio) / sr whisper_model = transcribe.get_whisper_model() - model_size = whisper_model.model_size + model_size = model if model else whisper_model.model_size - if not whisper_model.is_loaded() and not whisper_model._is_model_cached(model_size): + valid_sizes = list(WHISPER_HF_REPOS.keys()) + if model_size not in valid_sizes: + raise HTTPException( + status_code=400, + detail=f"Invalid model size '{model_size}'. Must be one of: {', '.join(valid_sizes)}", + ) + + already_loaded = whisper_model.is_loaded() and whisper_model.model_size == model_size + if not already_loaded and not whisper_model._is_model_cached(model_size): progress_model_name = f"whisper-{model_size}" task_manager = get_task_manager() @@ -59,7 +69,7 @@ async def download_whisper_background(): }, ) - text = await whisper_model.transcribe(tmp_path, language) + text = await whisper_model.transcribe(tmp_path, language, model_size) return models.TranscriptionResponse( text=text, diff --git a/backend/services/profiles.py b/backend/services/profiles.py index 46180f3c..d20c142d 100644 --- a/backend/services/profiles.py +++ b/backend/services/profiles.py @@ -22,7 +22,7 @@ Generation as DBGeneration, ) from ..models import EffectConfig -from ..utils.audio import validate_reference_audio, load_audio, save_audio +from ..utils.audio import validate_reference_audio, validate_and_load_reference_audio, load_audio, save_audio from ..utils.images import validate_image, process_avatar from ..utils.cache import _get_cache_dir, clear_profile_cache from .tts import get_tts_model @@ -117,11 +117,16 @@ async def add_profile_sample( Returns: Created sample """ + import asyncio + profile = db.query(DBVoiceProfile).filter_by(id=profile_id).first() if not profile: raise ValueError(f"Profile {profile_id} not found") - is_valid, error_msg = validate_reference_audio(audio_path) + # Validate and load audio in a single pass, off the event loop + is_valid, error_msg, audio, sr = await asyncio.to_thread( + validate_and_load_reference_audio, audio_path + ) if not is_valid: raise ValueError(f"Invalid reference audio: {error_msg}") @@ -130,8 +135,7 @@ async def add_profile_sample( profile_dir.mkdir(parents=True, exist_ok=True) dest_path = profile_dir / f"{sample_id}.wav" - audio, sr = load_audio(audio_path) - save_audio(audio, str(dest_path), sr) + await asyncio.to_thread(save_audio, audio, str(dest_path), sr) db_sample = DBProfileSample( id=sample_id, diff --git a/backend/utils/audio.py b/backend/utils/audio.py index 2c6de760..2016992c 100644 --- a/backend/utils/audio.py +++ b/backend/utils/audio.py @@ -217,22 +217,40 @@ def validate_reference_audio( Returns: Tuple of (is_valid, error_message) """ + result = validate_and_load_reference_audio( + audio_path, min_duration, max_duration, min_rms + ) + return (result[0], result[1]) + + +def validate_and_load_reference_audio( + audio_path: str, + min_duration: float = 2.0, + max_duration: float = 30.0, + min_rms: float = 0.01, +) -> Tuple[bool, Optional[str], Optional[np.ndarray], Optional[int]]: + """ + Validate and load reference audio in a single pass. + + Returns: + Tuple of (is_valid, error_message, audio_array, sample_rate) + """ try: audio, sr = load_audio(audio_path) duration = len(audio) / sr if duration < min_duration: - return False, f"Audio too short (minimum {min_duration} seconds)" + return False, f"Audio too short (minimum {min_duration} seconds)", None, None if duration > max_duration: - return False, f"Audio too long (maximum {max_duration} seconds)" + return False, f"Audio too long (maximum {max_duration} seconds)", None, None rms = np.sqrt(np.mean(audio**2)) if rms < min_rms: - return False, "Audio is too quiet or silent" + return False, "Audio is too quiet or silent", None, None if np.abs(audio).max() > 0.99: - return False, "Audio is clipping (reduce input gain)" + return False, "Audio is clipping (reduce input gain)", None, None - return True, None + return True, None, audio, sr except Exception as e: - return False, f"Error validating audio: {str(e)}" + return False, f"Error validating audio: {str(e)}", None, None diff --git a/backend/utils/progress.py b/backend/utils/progress.py index 418a88c7..56977ebd 100644 --- a/backend/utils/progress.py +++ b/backend/utils/progress.py @@ -246,6 +246,8 @@ async def subscribe(self, model_name: str): # Send heartbeat yield ": heartbeat\n\n" continue + except (BrokenPipeError, ConnectionResetError, asyncio.CancelledError): + logger.debug(f"SSE client disconnected from {model_name}") finally: # Remove from listeners if model_name in self._listeners: diff --git a/tauri/src-tauri/gen/voicebox.icns b/tauri/src-tauri/gen/voicebox.icns deleted file mode 100644 index e4492f52..00000000 Binary files a/tauri/src-tauri/gen/voicebox.icns and /dev/null differ