From 81133bad23b671d1f51e88260051f2b72f493df6 Mon Sep 17 00:00:00 2001 From: Juan Carlos M P Date: Wed, 22 Apr 2026 12:34:49 -0500 Subject: [PATCH] feat(windows): add native ROCm support for AMD GPUs Implements native ROCm architecture for Windows. - Adds backend build pipeline for voicebox-server-rocm.exe - Detects AMD GPUs dynamically and routes PyTorch allocations - Adds automatic download and update logic for ROCm dependencies - Refactors UI in GpuPage.tsx and GpuAcceleration.tsx to add AMD flows - Fixes 'Switch to CPU' lock on Windows via Tauri backend_override state - Resolves PyInstaller/rocm_sdk UnboundLocalError silent crashes - Resolves Numba/NumPy 2.x incompatibilities during Qwen3-TTS load - Resolves HF_HUB_OFFLINE Catch-22 for CustomVoice processor caching --- .gitignore | Bin 738 -> 948 bytes app/package.json | 8 +- .../ServerSettings/GpuAcceleration.test.tsx | 337 +++++++++++++ .../ServerSettings/GpuAcceleration.tsx | 406 +++++++++++---- app/src/components/ServerTab/GpuPage.tsx | 412 ++++++++++++---- app/src/i18n/locales/en/translation.json | 46 +- app/src/lib/api/client.ts | 18 + app/src/lib/api/models/ModelStatus.ts | 2 +- app/src/lib/api/types.ts | 20 + app/src/platform/types.ts | 1 + app/src/test/setup.ts | 1 + app/vite.config.ts | 7 +- backend/app.py | 25 +- backend/backends/base.py | 5 + backend/backends/hume_backend.py | 10 +- backend/backends/qwen_custom_voice_backend.py | 13 +- backend/build_binary.py | 280 +++++++++-- backend/pyi_rth_rocm_sdk.py | 85 ++++ backend/requirements-rocm.txt | 4 + backend/routes/__init__.py | 2 + backend/routes/health.py | 19 +- backend/routes/rocm.py | 79 +++ backend/server.py | 23 +- backend/services/rocm.py | 465 ++++++++++++++++++ backend/tests/test_amd_gpu_detect.py | 87 ++++ backend/tests/test_rocm_backends.py | 68 +++ backend/tests/test_rocm_build.py | 129 +++++ backend/tests/test_rocm_download.py | 203 ++++++++ backend/tests/test_rocm_requirements.py | 130 +++++ backend/utils/platform_detect.py | 50 +- tauri/src-tauri/src/main.rs | 176 +++++-- tauri/src/platform/lifecycle.ts | 9 + 32 files changed, 2823 insertions(+), 297 deletions(-) create mode 100644 app/src/components/ServerSettings/GpuAcceleration.test.tsx create mode 100644 app/src/test/setup.ts create mode 100644 backend/pyi_rth_rocm_sdk.py create mode 100644 backend/requirements-rocm.txt create mode 100644 backend/routes/rocm.py create mode 100644 backend/services/rocm.py create mode 100644 backend/tests/test_amd_gpu_detect.py create mode 100644 backend/tests/test_rocm_backends.py create mode 100644 backend/tests/test_rocm_build.py create mode 100644 backend/tests/test_rocm_download.py create mode 100644 backend/tests/test_rocm_requirements.py diff --git a/.gitignore b/.gitignore index bcc1927cfbb9acb3a74f2690d850495089f27f89..853c5060975fbeadd1fbd442d6dfb9c9a648813d 100644 GIT binary patch literal 948 zcma)4O^e$w5Y^e>|KO4yT0>;$KS)U^fu&u@(jH1NvNUnDktHF?Nqp&V?~}cOw%bE7 zo{{Fg`PQo1k|(RkN=>mWtW(quPK}0QQx@H5xpDH`l||e7NeX$QwgpggKYf=@{lM|9 zpSUz4!oB9vl?8vC(#hGfxRAYoyvW_>uZv@FgHL6#sy>d|sLGAWj|t97#{@=~tuvGQ zey#1%-7jU4MCd7#YA(FbN)3Hhbfc_>sAnUg;F@o|-w&b(lC$l%JCt^bsG1OgYeiy? 
z6t8oncy)04xsbmcz}OzzLvjKBPp5I{B3B5TLv2M8)w?lLSodTph(zi=8h{xQ-^`l# zI-Q9SI(t00ejbl;C>J6RH`{miqJy&oSxkguP>ak%7iOV+x@V}48e1s~blx~DO?b_p zW1e#oW6(_ua=n~7ZAHL7oBB0|f}2f@lp>cR{2RYGEva)iuBAo7zr*JcUWBBD;oe|t zOQDa`-o_-1A%w+C@FW-Do3_ebW0hTwLgroD@uj;b8oUc4Oh_|$OeMhpRdZZlX7piK zC5q{H|4zs=o^6xuu?ZAM1C7V?hyAE;X#@zukpkX0wR3CyDlq8(*>%1r% z43U78Ks_ydP7!8|vxYV7A3v=bB~6*(u6tj7a9ygLD-$?R-0l&qsc;;(KAzIAJ+LVw zieRHm&Jx`;Hq{L!PUIVcgT(us50e=Yc(7(6oHIwtTuK*nDrQk4I3u zV}ImS5T4I<8c4aJF+0Pids7L^cEpIqFuJ$k1l90rUDW9Lf|#57A&sr~kVcSCJQHjm zXJ1mBFfu4HVGYel6B_;D_g}k{7n?y^(@ADvvRB_o++$OV*HrBG=Wg)P6q%0R{6k5F z`kn}xlr`l=dQRTrkZ#J`NG~d~DeO&9Imk3xg-^>N5}OdPG}AKLhJNUCsQzQob?#080_};Q#;t diff --git a/app/package.json b/app/package.json index f149fb1a..0455e8ea 100644 --- a/app/package.json +++ b/app/package.json @@ -8,6 +8,7 @@ "build": "vite build", "typecheck": "tsc -p tsconfig.json --noEmit", "preview": "vite preview", + "test": "vitest", "lint": "biome lint src", "lint:fix": "biome lint --write src", "format": "biome format --write src", @@ -60,11 +61,16 @@ }, "devDependencies": { "@tailwindcss/vite": "^4.1.18", + "@testing-library/dom": "^10.4.0", + "@testing-library/jest-dom": "^6.5.0", + "@testing-library/react": "^16.0.0", "@types/react": "^18.3.0", "@types/react-dom": "^18.3.0", "@vitejs/plugin-react": "^4.3.0", + "jsdom": "^25.0.0", "tailwindcss": "^4.1.0", "typescript": "^5.6.0", - "vite": "^5.4.0" + "vite": "^5.4.0", + "vitest": "^2.1.0" } } diff --git a/app/src/components/ServerSettings/GpuAcceleration.test.tsx b/app/src/components/ServerSettings/GpuAcceleration.test.tsx new file mode 100644 index 00000000..25fb3412 --- /dev/null +++ b/app/src/components/ServerSettings/GpuAcceleration.test.tsx @@ -0,0 +1,337 @@ +import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; +import { fireEvent, render, screen, waitFor } from '@testing-library/react'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { GpuAcceleration } from './GpuAcceleration'; + +// Mock dependencies +vi.mock('@/lib/api/client', () => ({ + apiClient: { + getHealth: vi.fn(), + getCudaStatus: vi.fn(), + getRocmStatus: vi.fn(), + downloadCudaBackend: vi.fn(), + downloadRocmBackend: vi.fn(), + deleteCudaBackend: vi.fn(), + deleteRocmBackend: vi.fn(), + }, +})); + +vi.mock('@/lib/hooks/useServer', () => ({ + useServerHealth: vi.fn(), +})); + +vi.mock('@/platform/PlatformContext', () => ({ + usePlatform: vi.fn(), +})); + +vi.mock('@/stores/serverStore', () => ({ + useServerStore: vi.fn((selector) => selector({ serverUrl: 'http://localhost:8000' })), +})); + +import { apiClient } from '@/lib/api/client'; +import { useServerHealth } from '@/lib/hooks/useServer'; +import { usePlatform } from '@/platform/PlatformContext'; + +const mockedApiClient = vi.mocked(apiClient); +const mockedUseServerHealth = vi.mocked(useServerHealth); +const mockedUsePlatform = vi.mocked(usePlatform); + +describe('GpuAcceleration', () => { + let queryClient: QueryClient; + + beforeEach(() => { + queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + }, + }, + }); + + // Reset all mocks + vi.clearAllMocks(); + + // Default platform mock (Tauri app) + mockedUsePlatform.mockReturnValue({ + metadata: { isTauri: true }, + lifecycle: { + restartServer: vi.fn().mockResolvedValue(undefined), + setBackendOverride: vi.fn().mockResolvedValue(undefined), + }, + } as any); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + vi.restoreAllMocks(); + }); + + function renderComponent() { + return render( + + + , + ); + } + + 
it('renders CPU status when no GPU is available', async () => { + mockedUseServerHealth.mockReturnValue({ + data: { + status: 'healthy', + gpu_available: false, + backend_variant: 'cpu', + }, + isLoading: false, + } as any); + + mockedApiClient.getCudaStatus.mockResolvedValue({ + available: false, + active: false, + downloading: false, + }); + + mockedApiClient.getRocmStatus.mockResolvedValue({ + available: false, + active: false, + downloading: false, + }); + + renderComponent(); + + await waitFor(() => { + expect(screen.getByText('CPU')).toBeInTheDocument(); + }); + }); + + it('shows "Download AMD ROCm Backend" button when running CPU on AMD hardware', async () => { + mockedUseServerHealth.mockReturnValue({ + data: { + status: 'healthy', + gpu_available: false, + backend_variant: 'cpu', + }, + isLoading: false, + } as any); + + mockedApiClient.getCudaStatus.mockResolvedValue({ + available: false, + active: false, + downloading: false, + }); + + mockedApiClient.getRocmStatus.mockResolvedValue({ + available: false, + active: false, + downloading: false, + }); + + renderComponent(); + + await waitFor(() => { + expect(screen.getByText('Download AMD ROCm Backend')).toBeInTheDocument(); + }); + }); + + it('shows ROCm download progress via SSE events', async () => { + mockedUseServerHealth.mockReturnValue({ + data: { + status: 'healthy', + gpu_available: false, + backend_variant: 'cpu', + }, + isLoading: false, + } as any); + + mockedApiClient.getCudaStatus.mockResolvedValue({ + available: false, + active: false, + downloading: false, + }); + + mockedApiClient.getRocmStatus.mockResolvedValue({ + available: false, + active: false, + downloading: true, + download_progress: { + model_name: 'rocm-backend', + current: 0, + total: 1000, + progress: 0, + filename: 'Downloading ROCm libraries...', + status: 'downloading', + timestamp: new Date().toISOString(), + }, + }); + + // Mock EventSource — use vi.stubGlobal so vi.restoreAllMocks() in afterEach + // tears it down automatically and doesn't bleed into other tests. 
+ const mockEventSource = { + onmessage: null as ((event: MessageEvent) => void) | null, + onerror: null as (() => void) | null, + close: vi.fn(), + }; + + vi.stubGlobal('EventSource', vi.fn(() => mockEventSource)); + + renderComponent(); + + await waitFor(() => { + expect(screen.getByText('Downloading ROCm libraries...')).toBeInTheDocument(); + }); + + // Simulate SSE progress update + if (mockEventSource.onmessage) { + mockEventSource.onmessage( + new MessageEvent('message', { + data: JSON.stringify({ + model_name: 'rocm-backend', + current: 500, + total: 1000, + progress: 50, + filename: 'Downloading ROCm libraries...', + status: 'downloading', + timestamp: new Date().toISOString(), + }), + }), + ); + } + + await waitFor(() => { + expect(screen.getByText('50.0%')).toBeInTheDocument(); + }); + + // Simulate completion + if (mockEventSource.onmessage) { + mockEventSource.onmessage( + new MessageEvent('message', { + data: JSON.stringify({ + model_name: 'rocm-backend', + current: 1000, + total: 1000, + progress: 100, + filename: 'Extracting ROCm libraries...', + status: 'complete', + timestamp: new Date().toISOString(), + }), + }), + ); + } + }); + + it('shows "Switch to CPU Backend" when running ROCm', async () => { + mockedUseServerHealth.mockReturnValue({ + data: { + status: 'healthy', + gpu_available: true, + gpu_type: 'ROCm (AMD Radeon RX 7900 XTX)', + backend_variant: 'rocm', + vram_used_mb: 2048, + }, + isLoading: false, + } as any); + + renderComponent(); + + await waitFor(() => { + expect(screen.getByText('AMD Radeon RX 7900 XTX')).toBeInTheDocument(); + expect(screen.getByText('Switch to CPU Backend')).toBeInTheDocument(); + }); + }); + + it('shows "Switch to ROCm Backend" when ROCm is downloaded but not active', async () => { + mockedUseServerHealth.mockReturnValue({ + data: { + status: 'healthy', + gpu_available: false, + backend_variant: 'cpu', + }, + isLoading: false, + } as any); + + mockedApiClient.getCudaStatus.mockResolvedValue({ + available: false, + active: false, + downloading: false, + }); + + mockedApiClient.getRocmStatus.mockResolvedValue({ + available: true, + active: false, + downloading: false, + }); + + renderComponent(); + + await waitFor(() => { + expect(screen.getByText('Switch to ROCm Backend')).toBeInTheDocument(); + }); + }); + + it('calls downloadRocmBackend when AMD download button is clicked', async () => { + mockedUseServerHealth.mockReturnValue({ + data: { + status: 'healthy', + gpu_available: false, + backend_variant: 'cpu', + }, + isLoading: false, + } as any); + + mockedApiClient.getCudaStatus.mockResolvedValue({ + available: false, + active: false, + downloading: false, + }); + + mockedApiClient.getRocmStatus.mockResolvedValue({ + available: false, + active: false, + downloading: false, + }); + + mockedApiClient.downloadRocmBackend.mockResolvedValue({ + message: 'ROCm backend download started', + progress_key: 'rocm-backend', + }); + + renderComponent(); + + const downloadButton = await screen.findByText('Download AMD ROCm Backend'); + fireEvent.click(downloadButton); + + await waitFor(() => { + expect(mockedApiClient.downloadRocmBackend).toHaveBeenCalledTimes(1); + }); + }); + + it('calls setBackendOverride("cpu") when switching from ROCm to CPU', async () => { + const setBackendOverrideMock = vi.fn().mockResolvedValue(undefined); + mockedUsePlatform.mockReturnValue({ + metadata: { isTauri: true }, + lifecycle: { + restartServer: vi.fn().mockResolvedValue(undefined), + setBackendOverride: setBackendOverrideMock, + }, + } as any); + + 
mockedUseServerHealth.mockReturnValue({ + data: { + status: 'healthy', + gpu_available: true, + gpu_type: 'ROCm (AMD Radeon RX 7900 XTX)', + backend_variant: 'rocm', + vram_used_mb: 2048, + }, + isLoading: false, + } as any); + + renderComponent(); + + const switchButton = await screen.findByText('Switch to CPU Backend'); + fireEvent.click(switchButton); + + await waitFor(() => { + expect(setBackendOverrideMock).toHaveBeenCalledWith('cpu'); + }); + }); +}); diff --git a/app/src/components/ServerSettings/GpuAcceleration.tsx b/app/src/components/ServerSettings/GpuAcceleration.tsx index 7b2c9749..46e0d4bd 100644 --- a/app/src/components/ServerSettings/GpuAcceleration.tsx +++ b/app/src/components/ServerSettings/GpuAcceleration.tsx @@ -5,7 +5,7 @@ import { Button } from '@/components/ui/button'; import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; import { Progress } from '@/components/ui/progress'; import { apiClient } from '@/lib/api/client'; -import type { CudaDownloadProgress } from '@/lib/api/types'; +import type { CudaDownloadProgress, RocmDownloadProgress } from '@/lib/api/types'; import { useServerHealth } from '@/lib/hooks/useServer'; import { usePlatform } from '@/platform/PlatformContext'; import { useServerStore } from '@/stores/serverStore'; @@ -21,6 +21,9 @@ export function GpuAcceleration() { const [restartPhase, setRestartPhase] = useState('idle'); const [error, setError] = useState(null); const [downloadProgress, setDownloadProgress] = useState(null); + const [rocmDownloadProgress, setRocmDownloadProgress] = useState( + null, + ); const healthPollRef = useRef | null>(null); // Query CUDA backend status @@ -36,10 +39,26 @@ export function GpuAcceleration() { enabled: !!health, // Only fetch when backend is reachable }); + // Query ROCm backend status + const { + data: rocmStatus, + isLoading: _rocmStatusLoading, + refetch: refetchRocmStatus, + } = useQuery({ + queryKey: ['rocm-status', serverUrl], + queryFn: () => apiClient.getRocmStatus(), + refetchInterval: (query) => (query.state.status === 'pending' ? false : 10000), + retry: 1, + enabled: !!health, // Only fetch when backend is reachable + }); + // Derived state const isCurrentlyCuda = health?.backend_variant === 'cuda'; + const isCurrentlyRocm = health?.backend_variant === 'rocm'; const cudaAvailable = cudaStatus?.available ?? false; const cudaDownloading = cudaStatus?.downloading ?? false; + const rocmAvailable = rocmStatus?.available ?? false; + const rocmDownloading = rocmStatus?.downloading ?? 
false; // Clean up health poll on unmount useEffect(() => { @@ -51,7 +70,7 @@ export function GpuAcceleration() { }; }, []); - // SSE progress tracking during download + // SSE progress tracking during CUDA download useEffect(() => { if (!cudaDownloading || !serverUrl) { return; @@ -88,6 +107,43 @@ export function GpuAcceleration() { }; }, [cudaDownloading, serverUrl, refetchCudaStatus]); + // SSE progress tracking during ROCm download + useEffect(() => { + if (!rocmDownloading || !serverUrl) { + return; + } + + const eventSource = new EventSource(`${serverUrl}/backend/rocm-progress`); + + eventSource.onmessage = (event) => { + try { + const data = JSON.parse(event.data) as RocmDownloadProgress; + setRocmDownloadProgress(data); + + if (data.status === 'complete') { + eventSource.close(); + setRocmDownloadProgress(null); + refetchRocmStatus(); + } else if (data.status === 'error') { + eventSource.close(); + setError(data.error || 'Download failed'); + setRocmDownloadProgress(null); + refetchRocmStatus(); + } + } catch (e) { + console.error('Error parsing ROCm progress event:', e); + } + }; + + eventSource.onerror = () => { + eventSource.close(); + }; + + return () => { + eventSource.close(); + }; + }, [rocmDownloading, serverUrl, refetchRocmStatus]); + // Start aggressive health polling during restart const startHealthPolling = useCallback(() => { if (healthPollRef.current) return; @@ -113,7 +169,7 @@ export function GpuAcceleration() { }, 1000); }, [queryClient]); - const handleDownload = async () => { + const handleDownloadCuda = async () => { setError(null); try { await apiClient.downloadCudaBackend(); @@ -128,6 +184,21 @@ export function GpuAcceleration() { } }; + const handleDownloadRocm = async () => { + setError(null); + try { + await apiClient.downloadRocmBackend(); + refetchRocmStatus(); + } catch (e: unknown) { + const msg = e instanceof Error ? e.message : 'Failed to start download'; + if (msg.includes('already downloaded')) { + refetchRocmStatus(); + } else { + setError(msg); + } + } + }; + const handleRestart = async () => { setError(null); setRestartPhase('stopping'); @@ -154,18 +225,17 @@ export function GpuAcceleration() { } }; - const handleSwitchToCpu = async () => { - // To switch to CPU: delete the CUDA binary, then restart. - // start_server always prefers CUDA if present, so we must remove it first. + const handleSwitchToCpuFromCuda = async () => { setError(null); setRestartPhase('stopping'); try { - await apiClient.deleteCudaBackend(); + // Tell Rust launcher to skip GPU binary detection on next start. + // We cannot delete an active .exe on Windows, so we override instead. + await platform.lifecycle.setBackendOverride('cpu'); setRestartPhase('waiting'); startHealthPolling(); await platform.lifecycle.restartServer(); - // Invoke resolved — server is likely ready if (healthPollRef.current) { clearInterval(healthPollRef.current); healthPollRef.current = null; @@ -184,7 +254,36 @@ export function GpuAcceleration() { } }; - const handleDelete = async () => { + const handleSwitchToCpuFromRocm = async () => { + setError(null); + setRestartPhase('stopping'); + + try { + // Tell Rust launcher to skip GPU binary detection on next start. + // We cannot delete an active .exe on Windows, so we override instead. 
+ await platform.lifecycle.setBackendOverride('cpu'); + setRestartPhase('waiting'); + startHealthPolling(); + await platform.lifecycle.restartServer(); + if (healthPollRef.current) { + clearInterval(healthPollRef.current); + healthPollRef.current = null; + } + setRestartPhase('ready'); + queryClient.invalidateQueries(); + setTimeout(() => setRestartPhase('idle'), 2000); + } catch (e: unknown) { + setRestartPhase('idle'); + if (healthPollRef.current) { + clearInterval(healthPollRef.current); + healthPollRef.current = null; + } + setError(e instanceof Error ? e.message : 'Failed to switch to CPU'); + refetchRocmStatus(); + } + }; + + const handleDeleteCuda = async () => { setError(null); try { await apiClient.deleteCudaBackend(); @@ -194,6 +293,16 @@ export function GpuAcceleration() { } }; + const handleDeleteRocm = async () => { + setError(null); + try { + await apiClient.deleteRocmBackend(); + refetchRocmStatus(); + } catch (e: unknown) { + setError(e instanceof Error ? e.message : 'Failed to delete ROCm backend'); + } + }; + const formatBytes = (bytes: number): string => { if (bytes === 0) return '0 B'; const k = 1024; @@ -205,7 +314,7 @@ export function GpuAcceleration() { // Don't render until health data is available if (!health) return null; - // If the system already has native GPU (MPS, etc.), only show info - no CUDA needed + // If the system already has native GPU (MPS, ROCm active, etc.), only show info - no download needed const hasNativeGpu = health.gpu_available && !isCurrentlyCuda && @@ -241,8 +350,6 @@ export function GpuAcceleration() { )} - {/* Native GPU detected - no CUDA download needed */} - {/* Currently running CUDA - show switch back to CPU */} {isCurrentlyCuda && platform.metadata.isTauri && ( <> @@ -261,7 +368,50 @@ export function GpuAcceleration() { Running with CUDA GPU acceleration. Switch back to CPU if needed (you can re-download later).

- + + )} + {error && ( +
+ + {error} +
+ )} + + )} + + {/* Currently running ROCm - show switch back to CPU */} + {isCurrentlyRocm && platform.metadata.isTauri && ( + <> + {restartPhase !== 'idle' ? ( +
+ + + {restartPhase === 'stopping' && 'Stopping server...'} + {restartPhase === 'waiting' && 'Restarting server...'} + {restartPhase === 'ready' && 'Server restarted successfully!'} + +
+ ) : ( +
+

Running with AMD ROCm GPU acceleration. Switch back to CPU if needed (you can
                re-download later).
              

+ @@ -276,39 +426,169 @@ export function GpuAcceleration() { )} - {/* CUDA download/manage section - show when no native GPU and not currently running CUDA */} - {!hasNativeGpu && !isCurrentlyCuda && ( + {/* Backend download/manage sections - show when no native GPU and not currently running GPU */} + {!hasNativeGpu && !isCurrentlyCuda && !isCurrentlyRocm && ( <> - {/* Download progress (manual download or auto-update) */} - {cudaDownloading && downloadProgress && ( -
-
-
- - - {downloadProgress.filename || - (cudaAvailable - ? 'Updating CUDA backend...' - : 'Downloading CUDA backend...')} - + {/* CUDA Section */} +
+
NVIDIA (CUDA)
+ + {/* CUDA Download progress */} + {cudaDownloading && downloadProgress && ( +
+
+
+ + + {downloadProgress.filename || + (cudaAvailable + ? 'Updating CUDA backend...' + : 'Downloading CUDA backend...')} + +
+ {downloadProgress.total > 0 && ( + + {downloadProgress.progress.toFixed(1)}% + + )}
{downloadProgress.total > 0 && ( - - {downloadProgress.progress.toFixed(1)}% - + <> + +
+ {formatBytes(downloadProgress.current)} /{' '} + {formatBytes(downloadProgress.total)} +
+ )}
- {downloadProgress.total > 0 && ( - <> - -
- {formatBytes(downloadProgress.current)} /{' '} - {formatBytes(downloadProgress.total)} + )} + + {/* CUDA Actions */} + {restartPhase === 'idle' && !cudaDownloading && ( +
+ {!cudaAvailable && ( +
+

+ Download the CUDA backend (~2.4 GB) for NVIDIA GPU acceleration. Requires an + NVIDIA GPU with CUDA support. +

+
- - )} -
- )} + )} + + {cudaAvailable && platform.metadata.isTauri && ( +
+

+ CUDA backend is downloaded and ready. Restart the server to enable GPU + acceleration. +

+ +
+ )} + + {cudaAvailable && ( + + )} +
+ )} +
+ + {/* Divider */} +
+ + {/* ROCm Section */} +
+
AMD (ROCm)
+ + {/* ROCm Download progress */} + {rocmDownloading && rocmDownloadProgress && ( +
+
+
+ + + {rocmDownloadProgress.filename || + (rocmAvailable + ? 'Updating ROCm backend...' + : 'Downloading ROCm backend...')} + +
+ {rocmDownloadProgress.total > 0 && ( + + {rocmDownloadProgress.progress.toFixed(1)}% + + )} +
+ {rocmDownloadProgress.total > 0 && ( + <> + +
+ {formatBytes(rocmDownloadProgress.current)} /{' '} + {formatBytes(rocmDownloadProgress.total)} +
+ + )} +
+ )} + + {/* ROCm Actions */} + {restartPhase === 'idle' && !rocmDownloading && ( +
+ {!rocmAvailable && ( +
+

+ Download the ROCm backend (~2-3 GB) for AMD GPU acceleration. Requires an + AMD Radeon GPU with ROCm support. +

+ +
+ )} + + {rocmAvailable && platform.metadata.isTauri && ( +
+

+ ROCm backend is downloaded and ready. Restart the server to enable AMD GPU + acceleration. +

+ +
+ )} + + {rocmAvailable && ( + + )} +
+ )} +
{/* Restart in progress */} {restartPhase !== 'idle' && ( @@ -329,52 +609,6 @@ export function GpuAcceleration() { {error}
)} - - {/* Actions */} - {restartPhase === 'idle' && !cudaDownloading && ( -
- {/* Not downloaded yet - show download button */} - {!cudaAvailable && ( -
-

- Download the CUDA backend (~2.4 GB) for NVIDIA GPU acceleration. Requires an - NVIDIA GPU with CUDA support. -

- -
- )} - - {/* Downloaded but not active - show switch button */} - {cudaAvailable && platform.metadata.isTauri && ( -
-

- CUDA backend is downloaded and ready. Restart the server to enable GPU - acceleration. -

- -
- )} - - {/* Delete option when downloaded (and not active) */} - {cudaAvailable && ( - - )} -
- )} )} diff --git a/app/src/components/ServerTab/GpuPage.tsx b/app/src/components/ServerTab/GpuPage.tsx index 0caae3aa..8613e65e 100644 --- a/app/src/components/ServerTab/GpuPage.tsx +++ b/app/src/components/ServerTab/GpuPage.tsx @@ -5,7 +5,7 @@ import { useTranslation } from 'react-i18next'; import { Button } from '@/components/ui/button'; import { Progress } from '@/components/ui/progress'; import { apiClient } from '@/lib/api/client'; -import type { CudaDownloadProgress, HealthResponse } from '@/lib/api/types'; +import type { CudaDownloadProgress, RocmDownloadProgress, HealthResponse } from '@/lib/api/types'; import { useServerHealth } from '@/lib/hooks/useServer'; import { usePlatform } from '@/platform/PlatformContext'; import { useServerStore } from '@/stores/serverStore'; @@ -50,7 +50,10 @@ function GpuInfoCard({ health }: { health: HealthResponse }) { : null; const gpuBackend = hasGpu ? health.gpu_type!.replace(/\s*\(.+\)$/, '') : null; const isApple = gpuBackend === 'MPS' || gpuBackend === 'Metal'; - const showBackendVariant = health.backend_variant && health.backend_variant !== 'cpu'; + const showBackendVariant = + health.backend_variant && + health.backend_variant !== 'cpu' && + health.backend_variant.toLowerCase() !== gpuBackend?.toLowerCase(); return (
@@ -115,10 +118,14 @@ export function GpuPage() { const [restartPhase, setRestartPhase] = useState('idle'); const [error, setError] = useState(null); + const [cudaStreaming, setCudaStreaming] = useState(false); + const [rocmStreaming, setRocmStreaming] = useState(false); const [downloadProgress, setDownloadProgress] = useState(null); + const [rocmDownloadProgress, setRocmDownloadProgress] = useState( + null, + ); const healthPollRef = useRef | null>(null); - // Hold the latest `t` in a ref so the CUDA progress SSE effect below doesn't - // tear down and reconnect the EventSource every time the language changes. + const tRef = useRef(t); useEffect(() => { tRef.current = t; @@ -136,9 +143,24 @@ export function GpuPage() { enabled: !!health, }); + const { + data: rocmStatus, + isLoading: _rocmStatusLoading, + refetch: refetchRocmStatus, + } = useQuery({ + queryKey: ['rocm-status', serverUrl], + queryFn: () => apiClient.getRocmStatus(), + refetchInterval: (query) => (query.state.status === 'pending' ? false : 10000), + retry: 1, + enabled: !!health, + }); + const isCurrentlyCuda = health?.backend_variant === 'cuda'; + const isCurrentlyRocm = health?.backend_variant === 'rocm'; const cudaAvailable = cudaStatus?.available ?? false; const cudaDownloading = cudaStatus?.downloading ?? false; + const rocmAvailable = rocmStatus?.available ?? false; + const rocmDownloading = rocmStatus?.downloading ?? false; useEffect(() => { return () => { @@ -150,7 +172,7 @@ export function GpuPage() { }, []); useEffect(() => { - if (!cudaDownloading || !serverUrl) return; + if ((!cudaDownloading && !cudaStreaming) || !serverUrl) return; const eventSource = new EventSource(`${serverUrl}/backend/cuda-progress`); @@ -162,11 +184,13 @@ export function GpuPage() { if (data.status === 'complete') { eventSource.close(); setDownloadProgress(null); + setCudaStreaming(false); refetchCudaStatus(); } else if (data.status === 'error') { eventSource.close(); setError(data.error || tRef.current('settings.gpu.errors.downloadFailed')); setDownloadProgress(null); + setCudaStreaming(false); refetchCudaStatus(); } } catch (e) { @@ -176,12 +200,50 @@ export function GpuPage() { eventSource.onerror = () => { eventSource.close(); + setCudaStreaming(false); + }; + + return () => { + eventSource.close(); + }; + }, [cudaDownloading, cudaStreaming, serverUrl, refetchCudaStatus]); + + useEffect(() => { + if ((!rocmDownloading && !rocmStreaming) || !serverUrl) return; + + const eventSource = new EventSource(`${serverUrl}/backend/rocm-progress`); + + eventSource.onmessage = (event) => { + try { + const data = JSON.parse(event.data) as RocmDownloadProgress; + setRocmDownloadProgress(data); + + if (data.status === 'complete') { + eventSource.close(); + setRocmDownloadProgress(null); + setRocmStreaming(false); + refetchRocmStatus(); + } else if (data.status === 'error') { + eventSource.close(); + setError(data.error || tRef.current('settings.gpu.errors.downloadFailed')); + setRocmDownloadProgress(null); + setRocmStreaming(false); + refetchRocmStatus(); + } + } catch (e) { + console.error('Error parsing ROCm progress event:', e); + } + }; + + eventSource.onerror = () => { + eventSource.close(); + setRocmStreaming(false); }; return () => { eventSource.close(); }; - }, [cudaDownloading, serverUrl, refetchCudaStatus]); + }, [rocmDownloading, rocmStreaming, serverUrl, refetchRocmStatus]); const clearHealthPolling = useCallback(() => { if (healthPollRef.current) { @@ -224,10 +286,11 @@ export function GpuPage() { [platform, startHealthPolling, 
clearHealthPolling], ); - const handleDownload = async () => { + const handleDownloadCuda = async () => { setError(null); try { await apiClient.downloadCudaBackend(); + setCudaStreaming(true); refetchCudaStatus(); } catch (e: unknown) { const msg = e instanceof Error ? e.message : t('settings.gpu.errors.downloadStart'); @@ -239,6 +302,22 @@ export function GpuPage() { } }; + const handleDownloadRocm = async () => { + setError(null); + try { + await apiClient.downloadRocmBackend(); + setRocmStreaming(true); + refetchRocmStatus(); + } catch (e: unknown) { + const msg = e instanceof Error ? e.message : t('settings.gpu.errors.downloadStart'); + if (msg.includes('already downloaded')) { + refetchRocmStatus(); + } else { + setError(msg); + } + } + }; + const handleRestart = async () => { setError(null); try { @@ -252,15 +331,43 @@ export function GpuPage() { setError(null); setRestartPhase('stopping'); try { - await apiClient.deleteCudaBackend(); + await platform.lifecycle.setBackendOverride('cpu'); await restartServerWithPolling(t('settings.gpu.errors.switchCpu')); } catch (e: unknown) { + setRestartPhase('idle'); setError(e instanceof Error ? e.message : t('settings.gpu.errors.switchCpu')); refetchCudaStatus(); + refetchRocmStatus(); } }; - const handleDelete = async () => { + const handleSwitchToCuda = async () => { + setError(null); + setRestartPhase('stopping'); + try { + await platform.lifecycle.setBackendOverride('cuda'); + await restartServerWithPolling(t('settings.gpu.errors.restartFailed')); + } catch (e: unknown) { + setRestartPhase('idle'); + setError(e instanceof Error ? e.message : t('settings.gpu.errors.restartFailed')); + refetchCudaStatus(); + } + }; + + const handleSwitchToRocm = async () => { + setError(null); + setRestartPhase('stopping'); + try { + await platform.lifecycle.setBackendOverride('rocm'); + await restartServerWithPolling(t('settings.gpu.errors.restartFailed')); + } catch (e: unknown) { + setRestartPhase('idle'); + setError(e instanceof Error ? e.message : t('settings.gpu.errors.restartFailed')); + refetchRocmStatus(); + } + }; + + const handleDeleteCuda = async () => { setError(null); try { await apiClient.deleteCudaBackend(); @@ -270,6 +377,16 @@ export function GpuPage() { } }; + const handleDeleteRocm = async () => { + setError(null); + try { + await apiClient.deleteRocmBackend(); + refetchRocmStatus(); + } catch (e: unknown) { + setError(e instanceof Error ? e.message : t('settings.gpu.errors.deleteRocm')); + } + }; + const formatBytes = (bytes: number): string => { if (bytes === 0) return '0 B'; const k = 1024; @@ -283,6 +400,7 @@ export function GpuPage() { const hasNativeGpu = health.gpu_available && !isCurrentlyCuda && + !isCurrentlyRocm && health.gpu_type && !health.gpu_type.includes('CUDA'); @@ -290,33 +408,186 @@ export function GpuPage() {
- {!hasNativeGpu && !isCurrentlyCuda && ( - - {cudaDownloading && downloadProgress && ( - -
- -
- - {downloadProgress.filename || - (cudaAvailable - ? t('settings.gpu.cuda.updating') - : t('settings.gpu.cuda.downloadingShort'))} - - - {downloadProgress.total > 0 - ? `${formatBytes(downloadProgress.current)} / ${formatBytes(downloadProgress.total)}` - : `${downloadProgress.progress.toFixed(1)}%`} - + {!hasNativeGpu && !isCurrentlyCuda && !isCurrentlyRocm && ( + <> + + {cudaDownloading && downloadProgress && ( + +
+ +
+ + {downloadProgress.filename || + (cudaAvailable + ? t('settings.gpu.cuda.updating') + : t('settings.gpu.cuda.downloadingShort'))} + + + {downloadProgress.total > 0 + ? `${formatBytes(downloadProgress.current)} / ${formatBytes(downloadProgress.total)}` + : `${downloadProgress.progress.toFixed(1)}%`} + +
-
- - )} + + )} + + {restartPhase !== 'idle' && ( + } + /> + )} + + {error && ( + +
+ + {error} +
+
+ )} + + {restartPhase === 'idle' && !cudaDownloading && ( + <> + {!cudaAvailable && !isCurrentlyCuda && ( + + + {t('settings.gpu.download.button')} + + } + /> + )} + + {cudaAvailable && !isCurrentlyCuda && platform.metadata.isTauri && ( + + + {t('settings.gpu.switchToCuda.button')} + + } + /> + )} - {restartPhase !== 'idle' && ( + {cudaAvailable && !isCurrentlyCuda && ( + + + {t('settings.gpu.remove.button')} + + } + /> + )} + + )} + + + + {rocmDownloading && rocmDownloadProgress && ( + +
+ +
+ + {rocmDownloadProgress.filename || + (rocmAvailable + ? t('settings.gpu.rocm.updating') + : t('settings.gpu.rocm.downloadingShort'))} + + + {rocmDownloadProgress.total > 0 + ? `${formatBytes(rocmDownloadProgress.current)} / ${formatBytes(rocmDownloadProgress.total)}` + : `${rocmDownloadProgress.progress.toFixed(1)}%`} + +
+
+
+ )} + + {restartPhase === 'idle' && !rocmDownloading && ( + <> + {!rocmAvailable && !isCurrentlyRocm && ( + + + {t('settings.gpu.downloadRocm.button')} + + } + /> + )} + + {rocmAvailable && !isCurrentlyRocm && platform.metadata.isTauri && ( + + + {t('settings.gpu.switchToRocm.button')} + + } + /> + )} + + {rocmAvailable && !isCurrentlyRocm && ( + + + {t('settings.gpu.removeRocm.button')} + + } + /> + )} + + )} +
+ + )} + + {(isCurrentlyCuda || isCurrentlyRocm) && platform.metadata.isTauri && ( + + {restartPhase !== 'idle' ? ( } /> + ) : ( + + + {t('settings.gpu.switchToCpu.button')} + + } + /> )} - {error && (
@@ -337,67 +618,6 @@ export function GpuPage() {
)} - - {restartPhase === 'idle' && !cudaDownloading && ( - <> - {!cudaAvailable && !isCurrentlyCuda && ( - - - {t('settings.gpu.download.button')} - - } - /> - )} - - {cudaAvailable && !isCurrentlyCuda && platform.metadata.isTauri && ( - - - {t('settings.gpu.switchToCuda.button')} - - } - /> - )} - - {isCurrentlyCuda && platform.metadata.isTauri && ( - - - {t('settings.gpu.switchToCpu.button')} - - } - /> - )} - - {cudaAvailable && !isCurrentlyCuda && ( - - - {t('settings.gpu.remove.button')} - - } - /> - )} - - )}
)} diff --git a/app/src/i18n/locales/en/translation.json b/app/src/i18n/locales/en/translation.json index a449b25f..e132b7cf 100644 --- a/app/src/i18n/locales/en/translation.json +++ b/app/src/i18n/locales/en/translation.json @@ -581,8 +581,13 @@ "description": "Choose the display language for Voicebox." }, "general": { - "docs": { "title": "Read the Docs" }, - "discord": { "title": "Join the Discord", "subtitle": "Get help & share voices" }, + "docs": { + "title": "Read the Docs" + }, + "discord": { + "title": "Join the Discord", + "subtitle": "Get help & share voices" + }, "serverUrl": { "title": "Server URL", "description": "The address of your voicebox backend server.", @@ -685,11 +690,15 @@ "active": "Active", "cuda": { "title": "CUDA Backend", + "activeTitle": "CUDA Backend Active", "description": "NVIDIA GPU acceleration via a downloadable CUDA backend.", "downloading": "Downloading CUDA backend…", "downloadingShort": "Downloading…", "updating": "Updating…" }, + "activeBackend": { + "description": "GPU acceleration is currently enabled." + }, "restart": { "ready": "Server restarted successfully", "waiting": "Restarting server…", @@ -707,10 +716,9 @@ }, "switchToCpu": { "title": "Switch to CPU backend", - "description": "Disable GPU acceleration. You can re-download CUDA later.", + "description": "Disable GPU acceleration. You can re-download the GPU backend later.", "button": "Switch" - }, - "remove": { + }, "remove": { "title": "Remove CUDA backend", "description": "Delete the downloaded CUDA binary to free disk space.", "button": "Remove" @@ -720,9 +728,33 @@ "downloadStart": "Failed to start download", "restartFailed": "Restart failed", "switchCpu": "Failed to switch to CPU", - "deleteCuda": "Failed to delete CUDA backend" + "deleteCuda": "Failed to delete CUDA backend", + "deleteRocm": "Failed to delete ROCm backend" + }, + "footer": "Voicebox automatically detects and uses the best available GPU on your system. On Apple Silicon Macs, the MLX backend runs natively on the Neural Engine and GPU via Metal Performance Shaders (MPS), with no additional setup required. On Windows, you can download optional CUDA (NVIDIA) or ROCm (AMD) backends for hardware-accelerated inference. Intel XPU and DirectML are also supported where available through PyTorch. When no GPU is detected, Voicebox falls back to CPU — all engines still work, just slower.", + "rocm": { + "title": "AMD ROCm Backend", + "activeTitle": "ROCm Backend Active", + "description": "AMD GPU acceleration via a downloadable ROCm backend.", + "downloading": "Downloading ROCm backend…", + "downloadingShort": "Downloading…", + "updating": "Updating…" }, - "footer": "Voicebox automatically detects and uses the best available GPU on your system. On Apple Silicon Macs, the MLX backend runs natively on the Neural Engine and GPU via Metal Performance Shaders (MPS), with no additional setup required. On Windows and Linux with NVIDIA GPUs, you can download an optional CUDA backend for hardware-accelerated inference. AMD ROCm, Intel XPU, and DirectML are also supported where available through PyTorch. When no GPU is detected, Voicebox falls back to CPU — all engines still work, just slower." + "downloadRocm": { + "title": "Download AMD ROCm backend", + "description": "~2-3 GB download. Requires an AMD Radeon GPU with ROCm support.", + "button": "Download" + }, + "switchToRocm": { + "title": "Switch to ROCm backend", + "description": "ROCm backend is downloaded and ready. 
Restart to enable.", + "button": "Restart" + }, + "removeRocm": { + "title": "Remove ROCm backend", + "description": "Delete the downloaded ROCm binary to free disk space.", + "button": "Remove" + } }, "logs": { "title": "Server Logs", diff --git a/app/src/lib/api/client.ts b/app/src/lib/api/client.ts index 045ea454..08040ebc 100644 --- a/app/src/lib/api/client.ts +++ b/app/src/lib/api/client.ts @@ -19,6 +19,7 @@ import type { ModelStatusListResponse, PresetVoice, ProfileSampleResponse, + RocmStatus, StoryCreate, StoryDetailResponse, StoryItemBatchUpdate, @@ -536,6 +537,23 @@ class ApiClient { }); } + // ROCm Backend Management + async getRocmStatus(): Promise { + return this.request('/backend/rocm-status'); + } + + async downloadRocmBackend(): Promise<{ message: string; progress_key: string }> { + return this.request<{ message: string; progress_key: string }>('/backend/download-rocm', { + method: 'POST', + }); + } + + async deleteRocmBackend(): Promise<{ message: string }> { + return this.request<{ message: string }>('/backend/rocm', { + method: 'DELETE', + }); + } + // Stories async listStories(): Promise { return this.request('/stories'); diff --git a/app/src/lib/api/models/ModelStatus.ts b/app/src/lib/api/models/ModelStatus.ts index 0d744893..fdba4285 100644 --- a/app/src/lib/api/models/ModelStatus.ts +++ b/app/src/lib/api/models/ModelStatus.ts @@ -9,7 +9,7 @@ export type ModelStatus = { model_name: string; display_name: string; downloaded: boolean; - downloading?: boolean; // True if download is in progress + downloading?: boolean; // True if download is in progress size_mb?: number | null; loaded?: boolean; }; diff --git a/app/src/lib/api/types.ts b/app/src/lib/api/types.ts index 86e3012f..1516dde8 100644 --- a/app/src/lib/api/types.ts +++ b/app/src/lib/api/types.ts @@ -168,6 +168,26 @@ export interface CudaStatus { download_progress?: CudaDownloadProgress; } +export interface RocmDownloadProgress { + model_name: string; + current: number; + total: number; + progress: number; + filename?: string; + status: 'downloading' | 'extracting' | 'complete' | 'error'; + timestamp: string; + error?: string; +} + +export interface RocmStatus { + available: boolean; // ROCm binary exists on disk + active: boolean; // Currently running the ROCm binary + binary_path?: string; + rocm_libs_version?: string; + downloading: boolean; // Download in progress + download_progress?: RocmDownloadProgress; +} + export interface ModelProgress { model_name: string; current: number; diff --git a/app/src/platform/types.ts b/app/src/platform/types.ts index c6b06f44..2e11af9a 100644 --- a/app/src/platform/types.ts +++ b/app/src/platform/types.ts @@ -60,6 +60,7 @@ export interface PlatformLifecycle { stopServer(): Promise; restartServer(modelsDir?: string | null): Promise; setKeepServerRunning(keep: boolean): Promise; + setBackendOverride(backend?: string | null): Promise; setupWindowCloseHandler(): Promise; subscribeToServerLogs(callback: (entry: ServerLogEntry) => void): () => void; onServerReady?: () => void; diff --git a/app/src/test/setup.ts b/app/src/test/setup.ts new file mode 100644 index 00000000..7b0828bf --- /dev/null +++ b/app/src/test/setup.ts @@ -0,0 +1 @@ +import '@testing-library/jest-dom'; diff --git a/app/vite.config.ts b/app/vite.config.ts index 36bc168b..86b21952 100644 --- a/app/vite.config.ts +++ b/app/vite.config.ts @@ -1,7 +1,7 @@ import path from 'node:path'; import tailwindcss from '@tailwindcss/vite'; import react from '@vitejs/plugin-react'; -import { defineConfig } from 'vite'; +import { 
defineConfig } from 'vitest/config'; import { changelogPlugin } from './plugins/changelog'; export default defineConfig({ @@ -11,4 +11,9 @@ export default defineConfig({ '@': path.resolve(__dirname, './src'), }, }, + test: { + globals: true, + environment: 'jsdom', + setupFiles: ['./src/test/setup.ts'], + }, }); diff --git a/backend/app.py b/backend/app.py index 1cbac8a1..78ddcf10 100644 --- a/backend/app.py +++ b/backend/app.py @@ -35,11 +35,24 @@ def format(self, record): logger = logging.getLogger(__name__) -# AMD GPU environment variables must be set before torch import -if not os.environ.get("HSA_OVERRIDE_GFX_VERSION"): - os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0" -if not os.environ.get("MIOPEN_LOG_LEVEL"): - os.environ["MIOPEN_LOG_LEVEL"] = "4" +# HSA_OVERRIDE_GFX_VERSION=10.3.0 is an RDNA2 compatibility shim for older +# PyTorch ROCm builds. Only set it on the ROCm binary variant — applying it +# globally breaks CUDA and CPU builds, and is unnecessary on RDNA3/4 hardware. +if os.environ.get("VOICEBOX_BACKEND_VARIANT") == "rocm": + import platform + + if ( + os.environ.get("VOICEBOX_ROCM_FORCE_GFX1030") == "1" + and platform.system() == "Linux" + and not os.environ.get("HSA_OVERRIDE_GFX_VERSION") + ): + os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0" + + if ( + os.environ.get("VOICEBOX_ROCM_ENABLE_MIOPEN_LOG") == "1" + and not os.environ.get("MIOPEN_LOG_LEVEL") + ): + os.environ["MIOPEN_LOG_LEVEL"] = "4" import torch from fastapi import FastAPI @@ -245,8 +258,10 @@ async def startup_event(): logger.warning("GPU COMPATIBILITY: %s", _cuda_warning) from .services.cuda import check_and_update_cuda_binary + from .services.rocm import check_and_update_rocm_binary create_background_task(check_and_update_cuda_binary()) + create_background_task(check_and_update_rocm_binary()) try: progress_manager = get_progress_manager() diff --git a/backend/backends/base.py b/backend/backends/base.py index c566af10..70ec11ef 100644 --- a/backend/backends/base.py +++ b/backend/backends/base.py @@ -138,6 +138,11 @@ def check_cuda_compatibility() -> tuple[bool, str | None]: if not torch.cuda.is_available(): return True, None + # ROCm/HIP uses the cuda frontend but has different architecture names (gfx*). + # Skip NVIDIA-specific compute capability checks on AMD hardware. + if hasattr(torch.version, "hip") and torch.version.hip: + return True, None + major, minor = torch.cuda.get_device_capability(0) capability = f"{major}.{minor}" device_name = torch.cuda.get_device_name(0) diff --git a/backend/backends/hume_backend.py b/backend/backends/hume_backend.py index ecaa29b7..ac51b775 100644 --- a/backend/backends/hume_backend.py +++ b/backend/backends/hume_backend.py @@ -146,7 +146,15 @@ def _load_model_sync(self, model_size: str = "1B"): ) # Determine dtype — use bf16 on CUDA/XPU for ~50% memory savings - if device == "cuda" and torch.cuda.is_bf16_supported(): + # On ROCm/AMD, torch.cuda.is_bf16_supported() works via the HIP abstraction, + # but we wrap it defensively in case an older build lacks the symbol. 
+ _bf16_ok = False + if device == "cuda": + try: + _bf16_ok = torch.cuda.is_bf16_supported() + except Exception: + _bf16_ok = False + if _bf16_ok: model_dtype = torch.bfloat16 elif device == "xpu": # Intel Arc (Alchemist+) supports bf16 natively diff --git a/backend/backends/qwen_custom_voice_backend.py b/backend/backends/qwen_custom_voice_backend.py index 518f8926..5a75d3b5 100644 --- a/backend/backends/qwen_custom_voice_backend.py +++ b/backend/backends/qwen_custom_voice_backend.py @@ -78,7 +78,14 @@ def _get_model_path(self, model_size: str) -> str: def _is_model_cached(self, model_size: Optional[str] = None) -> bool: size = model_size or self.model_size - return is_model_cached(self._get_model_path(size)) + cv_repo = self._get_model_path(size) + # CustomVoice models depend on the Base model's tokenizer/processor config. + # If the Base model is not cached, the AutoProcessor will try to download + # it and crash if we force offline mode. + base_repo = f"Qwen/Qwen3-TTS-12Hz-{size}-Base" + return is_model_cached(cv_repo) and is_model_cached( + base_repo, required_files=["preprocessor_config.json"] + ) async def load_model_async(self, model_size: Optional[str] = None) -> None: if model_size is None: @@ -101,6 +108,8 @@ def _load_model_sync(self, model_size: str) -> None: with model_load_progress(model_name, is_cached): from qwen_tts import Qwen3TTSModel + from huggingface_hub import constants as hf_constants + tts_cache_dir = hf_constants.HF_HUB_CACHE model_path = self._get_model_path(model_size) logger.info("Loading Qwen CustomVoice %s on %s...", model_size, self.device) @@ -109,12 +118,14 @@ def _load_model_sync(self, model_size: str) -> None: if self.device == "cpu": self.model = Qwen3TTSModel.from_pretrained( model_path, + cache_dir=tts_cache_dir, torch_dtype=torch.float32, low_cpu_mem_usage=False, ) else: self.model = Qwen3TTSModel.from_pretrained( model_path, + cache_dir=tts_cache_dir, device_map=self.device, torch_dtype=torch.bfloat16, ) diff --git a/backend/build_binary.py b/backend/build_binary.py index 52bacbfe..f576780c 100644 --- a/backend/build_binary.py +++ b/backend/build_binary.py @@ -22,24 +22,34 @@ def is_apple_silicon(): return platform.system() == "Darwin" and platform.machine() == "arm64" -def build_server(cuda=False): +def build_server(cuda=False, rocm=False): """Build Python server as standalone binary. Args: cuda: If True, build with CUDA support and name the binary voicebox-server-cuda instead of voicebox-server. + rocm: If True, build with ROCm support and name the binary + voicebox-server-rocm instead of voicebox-server. """ + if cuda and rocm: + raise ValueError("Cannot build with both CUDA and ROCm support") + backend_dir = Path(__file__).parent - binary_name = "voicebox-server-cuda" if cuda else "voicebox-server" + if rocm: + binary_name = "voicebox-server-rocm" + elif cuda: + binary_name = "voicebox-server-cuda" + else: + binary_name = "voicebox-server" # PyInstaller arguments - # CUDA builds use --onedir so we can split the output into two archives: + # CUDA and ROCm builds use --onedir so we can split the output into two archives: # 1. Server core (~200-400MB) — versioned with the app - # 2. CUDA libs (~2GB) — versioned independently (only redownloaded on - # CUDA toolkit / torch major version changes) + # 2. GPU libs (~2GB) — versioned independently (only redownloaded on + # GPU toolkit / torch major version changes) # CPU builds remain --onefile for simplicity. 
- pack_mode = "--onedir" if cuda else "--onefile" + pack_mode = "--onedir" if (cuda or rocm) else "--onefile" args = [ "server.py", # Use server.py as entry point instead of main.py pack_mode, @@ -298,22 +308,74 @@ def build_server(cuda=False): ] ) - # Add CUDA-specific hidden imports - if cuda: - logger.info("Building with CUDA support") + # Add CUDA/ROCm-specific hidden imports + if cuda or rocm: + variant = "ROCm" if rocm else "CUDA" + logger.info("Building with %s support", variant) + gpu_hidden = [ + "--hidden-import", + "torch.cuda", + ] + # cudnn is NVIDIA-specific; ROCm uses MIOpen under the abstraction layer + if cuda: + gpu_hidden.extend( + [ + "--hidden-import", + "torch.backends.cudnn", + ] + ) + args.extend(gpu_hidden) + + if rocm: + # rocm_sdk imports its backend packages dynamically via + # importlib.import_module(py_package_name), which PyInstaller's + # static analyzer cannot see. We must collect them explicitly — + # otherwise only the pure-python rocm_sdk wrapper ships and + # rocm_sdk.find_libraries crashes with UnboundLocalError at boot. + # + # The backend packages also contain the HIP/MIOpen/hipBLAS DLLs + # under bin/ (plus ~750 MB of tensile kernel files under + # bin/rocblas/library and bin/hipblaslt/library) — collect-all + # walks the tree recursively so both DLLs and kernel data are + # bundled. See rocm_sdk/_dist_info.py for the package mapping. args.extend( [ + "--collect-all", + "rocm_sdk", + "--collect-all", + "_rocm_sdk_core", + "--collect-all", + "_rocm_sdk_libraries_custom", + "--collect-all", + "rocm_sdk_core", + "--collect-all", + "rocm_sdk_libraries_custom", "--hidden-import", - "torch.cuda", + "_rocm_sdk_core", "--hidden-import", - "torch.backends.cudnn", + "_rocm_sdk_libraries_custom", + "--hidden-import", + "rocm_sdk_core", + "--hidden-import", + "rocm_sdk_libraries_custom", + "--copy-metadata", + "rocm", + "--copy-metadata", + "rocm-sdk-core", + "--copy-metadata", + "rocm-sdk-libraries-custom", + # Repair rocm_sdk.find_libraries (masks UnboundLocalError + # with a readable ModuleNotFoundError on missing backends). + "--runtime-hook", + "pyi_rth_rocm_sdk.py", ] ) - else: - # Exclude NVIDIA CUDA packages from CPU-only builds to keep binary small. - # When building from a venv with CUDA torch installed, PyInstaller would - # bundle ~3GB of NVIDIA shared libraries. We exclude both the Python - # modules and the binary DLLs. + + # Exclude NVIDIA CUDA packages from non-CUDA builds to keep binary small. + # When building from a venv with CUDA torch installed, PyInstaller would + # bundle ~3GB of NVIDIA shared libraries. We exclude both the Python + # modules and the binary DLLs. This applies to CPU and ROCm builds. 
+ if not cuda: nvidia_packages = [ "nvidia", "nvidia.cublas", @@ -332,8 +394,8 @@ def build_server(cuda=False): for pkg in nvidia_packages: args.extend(["--exclude-module", pkg]) - # Add MLX-specific imports if building on Apple Silicon (never for CUDA builds) - if is_apple_silicon() and not cuda: + # Add MLX-specific imports if building on Apple Silicon (never for GPU builds) + if is_apple_silicon() and not cuda and not rocm: logger.info("Building for Apple Silicon - including MLX dependencies") args.extend( [ @@ -366,7 +428,7 @@ def build_server(cuda=False): "mlx_audio", ] ) - elif not cuda: + elif not cuda and not rocm: logger.info("Building for non-Apple Silicon platform - PyTorch only") dist_dir = str(backend_dir / "dist") @@ -387,19 +449,131 @@ def build_server(cuda=False): os.chdir(backend_dir) # For CPU builds on Windows, ensure we're using CPU-only torch. - # If CUDA torch is installed (local dev), swap to CPU torch before building, - # then restore CUDA torch after. This prevents PyInstaller from bundling - # ~3GB of CUDA DLLs into the CPU binary. - restore_cuda = False - if not cuda and platform.system() == "Windows": - import subprocess - - result = subprocess.run( - [sys.executable, "-c", "import torch; print(torch.version.cuda or '')"], capture_output=True, text=True - ) - has_cuda_torch = bool(result.stdout.strip()) - if has_cuda_torch: - logger.info("CUDA torch detected — installing CPU torch for CPU build...") + # If CUDA or ROCm torch is installed (local dev), swap to CPU torch before + # building, then restore afterwards. This prevents PyInstaller from bundling + # GPU libraries into the CPU binary. + restore_torch = None + try: + if not cuda and not rocm and platform.system() == "Windows": + import subprocess + + cuda_result = subprocess.run( + [sys.executable, "-c", "import torch; print(torch.version.cuda or '')"], capture_output=True, text=True + ) + rocm_result = subprocess.run( + [sys.executable, "-c", "import torch; print(torch.version.hip or '')"], capture_output=True, text=True + ) + + if cuda_result.stdout.strip(): + restore_torch = "cuda" + logger.info("CUDA torch detected — installing CPU torch for CPU build...") + subprocess.run( + [ + sys.executable, + "-m", + "pip", + "install", + "torch", + "torchvision", + "torchaudio", + "--index-url", + "https://download.pytorch.org/whl/cpu", + "--force-reinstall", + "--no-deps", + "-q", + ], + check=True, + ) + elif rocm_result.stdout.strip(): + restore_torch = "rocm" + logger.info("ROCm torch detected — installing CPU torch for CPU build...") + subprocess.run( + [ + sys.executable, + "-m", + "pip", + "install", + "torch", + "torchvision", + "torchaudio", + "--index-url", + "https://download.pytorch.org/whl/cpu", + "--force-reinstall", + "--no-deps", + "-q", + ], + check=True, + ) + + # For ROCm builds on Windows, ensure ROCm torch is installed. + if rocm and platform.system() == "Windows": + import subprocess + + if sys.implementation.name != "cpython" or sys.version_info[:2] != (3, 12): + raise RuntimeError( + "ROCm wheels are cp312-cp312-specific; " + f"got {sys.implementation.name} {sys.version.split()[0]}. " + "Use CPython 3.12 to build the ROCm binary." 
+ ) + + result = subprocess.run( + [sys.executable, "-c", "import torch; print(torch.version.hip or '')"], capture_output=True, text=True + ) + has_rocm_torch = bool(result.stdout.strip()) + if not has_rocm_torch: + logger.info("ROCm torch not detected — installing ROCm torch for ROCm build...") + + # Determine what to restore BEFORE overwriting the environment + cuda_result = subprocess.run( + [sys.executable, "-c", "import torch; print(torch.version.cuda or '')"], + capture_output=True, + text=True, + ) + if cuda_result.stdout.strip(): + restore_torch = "cuda" + else: + restore_torch = "cpu" + + # Now overwrite the environment safely + subprocess.run( + [ + sys.executable, + "-m", + "pip", + "install", + "https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/rocm_sdk_core-7.2.1-py3-none-win_amd64.whl", + "https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/rocm_sdk_devel-7.2.1-py3-none-win_amd64.whl", + "https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/rocm_sdk_libraries_custom-7.2.1-py3-none-win_amd64.whl", + "https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/rocm-7.2.1.tar.gz", + "--no-deps", + "-q", + ], + check=True, + ) + subprocess.run( + [ + sys.executable, + "-m", + "pip", + "install", + "https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/torch-2.9.1%2Brocm7.2.1-cp312-cp312-win_amd64.whl", + "https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/torchaudio-2.9.1%2Brocm7.2.1-cp312-cp312-win_amd64.whl", + "https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/torchvision-0.24.1%2Brocm7.2.1-cp312-cp312-win_amd64.whl", + "--force-reinstall", + "--no-deps", + "-q", + ], + check=True, + ) + + # Run PyInstaller + PyInstaller.__main__.run(args) + finally: + # Restore torch if we swapped it out (even on build failure) + if restore_torch == "cuda": + logger.info("Restoring CUDA torch...") + import subprocess + subprocess.run( [ sys.executable, @@ -410,21 +584,34 @@ def build_server(cuda=False): "torchvision", "torchaudio", "--index-url", - "https://download.pytorch.org/whl/cpu", + "https://download.pytorch.org/whl/cu128", "--force-reinstall", + "--no-deps", "-q", ], check=True, ) - restore_cuda = True + elif restore_torch == "rocm": + logger.info("Restoring ROCm torch...") + import subprocess - # Run PyInstaller - try: - PyInstaller.__main__.run(args) - finally: - # Restore CUDA torch if we swapped it out (even on build failure) - if restore_cuda: - logger.info("Restoring CUDA torch...") + subprocess.run( + [ + sys.executable, + "-m", + "pip", + "install", + "https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/torch-2.9.1%2Brocm7.2.1-cp312-cp312-win_amd64.whl", + "https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/torchaudio-2.9.1%2Brocm7.2.1-cp312-cp312-win_amd64.whl", + "https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/torchvision-0.24.1%2Brocm7.2.1-cp312-cp312-win_amd64.whl", + "--force-reinstall", + "--no-deps", + "-q", + ], + check=True, + ) + elif restore_torch == "cpu": + logger.info("Restoring CPU torch...") import subprocess subprocess.run( @@ -437,13 +624,15 @@ def build_server(cuda=False): "torchvision", "torchaudio", "--index-url", - "https://download.pytorch.org/whl/cu128", + "https://download.pytorch.org/whl/cpu", "--force-reinstall", + "--no-deps", "-q", ], check=True, ) + logger.info("Binary built in %s", backend_dir / "dist" / binary_name) @@ -454,5 +643,10 @@ def build_server(cuda=False): action="store_true", help="Build CUDA-enabled binary (voicebox-server-cuda)", ) + parser.add_argument( + "--rocm", + action="store_true", + help="Build ROCm-enabled binary 
(voicebox-server-rocm) for AMD GPUs",
+    )
     cli_args = parser.parse_args()
-    build_server(cuda=cli_args.cuda)
+    build_server(cuda=cli_args.cuda, rocm=cli_args.rocm)
diff --git a/backend/pyi_rth_rocm_sdk.py b/backend/pyi_rth_rocm_sdk.py
new file mode 100644
index 00000000..ffb51aad
--- /dev/null
+++ b/backend/pyi_rth_rocm_sdk.py
@@ -0,0 +1,85 @@
+"""
+Runtime hook: repair rocm_sdk.find_libraries under PyInstaller.
+
+rocm_sdk 7.2.x ships a find_libraries() with a latent bug: when the
+backend package (_rocm_sdk_core / _rocm_sdk_libraries_{target}) cannot
+be imported, the except clause records the miss but falls through to
+`py_root = Path(py_module.__file__).parent`, where py_module was never
+assigned. This surfaces as UnboundLocalError instead of the intended
+ModuleNotFoundError, masking the real cause.
+
+Frozen apps trip this because rocm_sdk imports the backend packages
+dynamically via importlib, which PyInstaller's static analyzer cannot
+see. We re-collect those packages in build_binary.py; this hook is
+defense-in-depth: it replaces find_libraries with a corrected version
+so any future missing-package case surfaces a readable error.
+"""
+
+
+def _patch_rocm_sdk():
+    try:
+        import rocm_sdk
+        from rocm_sdk import _dist_info
+    except ModuleNotFoundError as e:
+        if e.name not in {"rocm_sdk", "rocm_sdk._dist_info"}:
+            raise
+        return
+
+    import importlib
+    import platform
+    from pathlib import Path
+
+    def find_libraries(*shortnames):
+        paths = []
+        missing_extras = set()
+        is_windows = platform.system() == "Windows"
+        for shortname in shortnames:
+            try:
+                lib_entry = _dist_info.ALL_LIBRARIES[shortname]
+            except KeyError:
+                raise ModuleNotFoundError(f"Unknown rocm library '{shortname}'") from None
+
+            if is_windows and not lib_entry.dll_pattern:
+                continue
+
+            package = lib_entry.package
+            target_family = None
+            if package.is_target_specific:
+                target_family = _dist_info.determine_target_family()
+            py_package_name = package.get_py_package_name(target_family)
+            try:
+                py_module = importlib.import_module(py_package_name)
+            except ModuleNotFoundError as e:
+                if e.name != py_package_name:
+                    raise
+                missing_extras.add(package.logical_name)
+                continue
+
+            py_root = Path(py_module.__file__).parent
+            if is_windows:
+                relpath = py_root / lib_entry.windows_relpath
+                entry_pattern = lib_entry.dll_pattern
+            else:
+                relpath = py_root / lib_entry.posix_relpath
+                entry_pattern = lib_entry.so_pattern
+            matching_paths = sorted(relpath.glob(entry_pattern))
+            if len(matching_paths) == 0:
+                raise FileNotFoundError(
+                    f"Could not find rocm library '{shortname}' at path "
+                    f"'{relpath}', no match for pattern '{entry_pattern}'"
+                )
+            paths.append(matching_paths[0])
+
+        if missing_extras:
+            raise ModuleNotFoundError(
+                f"Missing required rocm backend packages: "
+                f"{', '.join(sorted(missing_extras))}. The frozen build did "
+                f"not bundle _rocm_sdk_core / _rocm_sdk_libraries_. "
+                f"Check build_binary.py --collect-all flags."
+ ) + return paths + + rocm_sdk.find_libraries = find_libraries + + +_patch_rocm_sdk() diff --git a/backend/requirements-rocm.txt b/backend/requirements-rocm.txt new file mode 100644 index 00000000..8fd6d1b3 --- /dev/null +++ b/backend/requirements-rocm.txt @@ -0,0 +1,4 @@ +--extra-index-url https://repo.radeon.com/rocm/windows/rocm-rel-7.2.1/ +torch==2.9.1+rocm7.2.1 +torchaudio==2.9.1+rocm7.2.1 +torchvision==0.24.1+rocm7.2.1 \ No newline at end of file diff --git a/backend/routes/__init__.py b/backend/routes/__init__.py index 2ee2c956..0e80da9b 100644 --- a/backend/routes/__init__.py +++ b/backend/routes/__init__.py @@ -17,6 +17,7 @@ def register_routers(app: FastAPI) -> None: from .models import router as models_router from .tasks import router as tasks_router from .cuda import router as cuda_router + from .rocm import router as rocm_router app.include_router(health_router) app.include_router(profiles_router) @@ -30,3 +31,4 @@ def register_routers(app: FastAPI) -> None: app.include_router(models_router) app.include_router(tasks_router) app.include_router(cuda_router) + app.include_router(rocm_router) diff --git a/backend/routes/health.py b/backend/routes/health.py index 79c513f5..43272ec4 100644 --- a/backend/routes/health.py +++ b/backend/routes/health.py @@ -103,7 +103,10 @@ async def health(): gpu_type = None if has_cuda: - gpu_type = f"CUDA ({torch.cuda.get_device_name(0)})" + if hasattr(torch.version, "hip") and torch.version.hip: + gpu_type = f"ROCm ({torch.cuda.get_device_name(0)})" + else: + gpu_type = f"CUDA ({torch.cuda.get_device_name(0)})" elif has_mps: gpu_type = "MPS (Apple Silicon)" elif backend_type == "mlx": @@ -164,6 +167,15 @@ async def health(): except Exception: pass + default_variant = "cpu" + if has_cuda: + if hasattr(torch.version, "hip") and torch.version.hip: + default_variant = "rocm" + else: + default_variant = "cuda" + elif has_xpu: + default_variant = "xpu" + return models.HealthResponse( status="healthy", model_loaded=model_loaded, @@ -173,10 +185,7 @@ async def health(): gpu_type=gpu_type, vram_used_mb=vram_used, backend_type=backend_type, - backend_variant=os.environ.get( - "VOICEBOX_BACKEND_VARIANT", - "cuda" if torch.cuda.is_available() else ("xpu" if has_xpu else "cpu"), - ), + backend_variant=os.environ.get("VOICEBOX_BACKEND_VARIANT", default_variant), gpu_compatibility_warning=gpu_compat_warning, ) diff --git a/backend/routes/rocm.py b/backend/routes/rocm.py new file mode 100644 index 00000000..56e08ce8 --- /dev/null +++ b/backend/routes/rocm.py @@ -0,0 +1,79 @@ +"""ROCm backend management endpoints.""" + +import logging + +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse + +from ..services.task_queue import create_background_task +from ..utils.progress import get_progress_manager + +router = APIRouter() + +logger = logging.getLogger(__name__) + + +@router.get("/backend/rocm-status") +async def get_rocm_status(): + """Get ROCm backend download/availability status.""" + from ..services import rocm + + return rocm.get_rocm_status() + + +@router.post("/backend/download-rocm") +async def download_rocm_backend(): + """Download the ROCm backend binary.""" + from ..services import rocm + + progress_manager = get_progress_manager() + existing = progress_manager.get_progress(rocm.PROGRESS_KEY) + if existing and existing.get("status") in {"downloading", "extracting"}: + raise HTTPException(status_code=409, detail="ROCm backend download already in progress") + + async def _download(): + try: + await 
rocm.download_rocm_binary() + except Exception as e: + logger.error("ROCm download failed: %s", e) + + create_background_task(_download()) + return {"message": "ROCm backend download started", "progress_key": rocm.PROGRESS_KEY} + + +@router.delete("/backend/rocm") +async def delete_rocm_backend(): + """Delete the downloaded ROCm backend binary.""" + from ..services import rocm + + if rocm.is_rocm_active(): + raise HTTPException( + status_code=409, + detail="Cannot delete ROCm backend while it is active. Switch to CPU first.", + ) + + deleted = await rocm.delete_rocm_binary() + if not deleted: + raise HTTPException(status_code=404, detail="No ROCm backend found to delete") + + return {"message": "ROCm backend deleted"} + + +@router.get("/backend/rocm-progress") +async def get_rocm_download_progress(): + """Get ROCm backend download progress via Server-Sent Events.""" + progress_manager = get_progress_manager() + + async def event_generator(): + async for event in progress_manager.subscribe("rocm-backend"): + yield event + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) diff --git a/backend/server.py b/backend/server.py index 5f8cc0f6..047f4fbd 100644 --- a/backend/server.py +++ b/backend/server.py @@ -7,6 +7,7 @@ import sys import os +import re # On Windows with --noconsole (PyInstaller), sys.stdout/stderr are None. # They can also be broken file objects in some edge cases. @@ -47,6 +48,17 @@ def _is_writable(stream): print(f"voicebox-server {__version__}") sys.exit(0) +# Detect backend variant from binary name BEFORE importing backend modules +# so that env-var guards in app.py (e.g. HSA_OVERRIDE_GFX_VERSION) fire at import time. +_binary_name = os.path.basename(sys.executable).lower() +if re.search(r"voicebox-server-rocm(\.exe)?$", _binary_name): + os.environ["VOICEBOX_BACKEND_VARIANT"] = "rocm" +elif re.search(r"voicebox-server-cuda(\.exe)?$", _binary_name): + os.environ["VOICEBOX_BACKEND_VARIANT"] = "cuda" +else: + os.environ.setdefault("VOICEBOX_BACKEND_VARIANT", "cpu") + + import logging # Set up logging FIRST, before any imports that might fail @@ -260,16 +272,7 @@ def _watch(): if args.parent_pid is not None and args.parent_pid <= 0: parser.error("--parent-pid must be a positive integer") - # Detect backend variant from binary name - # voicebox-server-cuda → sets VOICEBOX_BACKEND_VARIANT=cuda - import os - binary_name = os.path.basename(sys.executable).lower() - if "cuda" in binary_name: - os.environ["VOICEBOX_BACKEND_VARIANT"] = "cuda" - logger.info("Backend variant: CUDA") - else: - os.environ["VOICEBOX_BACKEND_VARIANT"] = "cpu" - logger.info("Backend variant: CPU") + logger.info(f"Backend variant: {os.environ.get('VOICEBOX_BACKEND_VARIANT', 'cpu').upper()}") # Register parent watchdog to start after server is fully ready if args.parent_pid is not None: diff --git a/backend/services/rocm.py b/backend/services/rocm.py new file mode 100644 index 00000000..24a0f2f4 --- /dev/null +++ b/backend/services/rocm.py @@ -0,0 +1,465 @@ +""" +ROCm backend download, assembly, and verification. + +Downloads two archives from GitHub Releases: + 1. Server core (voicebox-server-rocm.tar.gz) — the exe + non-AMD deps, + versioned with the app. + 2. ROCm libs (rocm-libs-{version}.tar.gz) — AMD runtime libraries, + versioned independently (only redownloaded on ROCm toolkit bump). 
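+
+For illustration, under the URL scheme below a hypothetical app release
+v0.3.0 would fetch:
+    {GITHUB_RELEASES_URL}/v0.3.0/voicebox-server-rocm.tar.gz
+    {GITHUB_RELEASES_URL}/rocm7.2-v1/rocm-libs-rocm7.2-v1.tar.gz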
+ +Both archives are extracted into {data_dir}/backends/rocm/ which forms the +complete PyInstaller --onedir directory structure that torch expects. +""" + +import asyncio +import hashlib +import json +import logging +import os +import shutil +import sys +import tarfile +from pathlib import Path +from typing import Optional + +from ..config import get_data_dir +from ..utils.progress import get_progress_manager +from .. import __version__ + +logger = logging.getLogger(__name__) + +GITHUB_RELEASES_URL = "https://github.com/jamiepine/voicebox/releases/download" + +PROGRESS_KEY = "rocm-backend" + +# The current expected ROCm libs version. Bump this when we change the +# ROCm toolkit version or torch's ROCm dependency changes (e.g. rocm7.2 -> rocm7.4). +ROCM_LIBS_VERSION = "rocm7.2-v1" + +# Prevents concurrent download_rocm_binary() calls from racing on the same +# temp file. The auto-update background task and the manual HTTP endpoint +# can both invoke download_rocm_binary(); without this lock the progress- +# manager status check is a TOCTOU race. +_download_lock = asyncio.Lock() + + +def get_backends_dir() -> Path: + """Directory where downloaded backend binaries are stored.""" + d = get_data_dir() / "backends" + d.mkdir(parents=True, exist_ok=True) + return d + + +def get_rocm_dir() -> Path: + """Directory where the ROCm backend (onedir) is extracted.""" + d = get_backends_dir() / "rocm" + d.mkdir(parents=True, exist_ok=True) + return d + + +def get_rocm_exe_name() -> str: + """Platform-specific ROCm executable filename.""" + if sys.platform == "win32": + return "voicebox-server-rocm.exe" + return "voicebox-server-rocm" + + +def get_rocm_binary_path() -> Optional[Path]: + """Return path to the ROCm executable if it exists inside the onedir.""" + p = get_rocm_dir() / get_rocm_exe_name() + if p.exists(): + return p + return None + + +def get_rocm_libs_manifest_path() -> Path: + """Path to the rocm-libs.json manifest inside the ROCm dir.""" + return get_rocm_dir() / "rocm-libs.json" + + +def get_installed_rocm_libs_version() -> Optional[str]: + """Read the installed ROCm libs version from rocm-libs.json, or None.""" + manifest_path = get_rocm_libs_manifest_path() + if not manifest_path.exists(): + return None + try: + data = json.loads(manifest_path.read_text()) + return data.get("version") + except Exception as e: + logger.warning(f"Could not read rocm-libs.json: {e}") + return None + + +def is_rocm_active() -> bool: + """Check if the current process is the ROCm binary. + + The ROCm binary sets this env var on startup (see server.py). 
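+
+    Example (mirrors test_rocm_download.py): with VOICEBOX_BACKEND_VARIANT=rocm
+    in the environment this returns True; unset or any other value yields False.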
+ """ + return os.environ.get("VOICEBOX_BACKEND_VARIANT") == "rocm" + + +def get_rocm_status() -> dict: + """Get current ROCm backend status for the API.""" + progress_manager = get_progress_manager() + rocm_path = get_rocm_binary_path() + progress = progress_manager.get_progress(PROGRESS_KEY) + rocm_libs_version = get_installed_rocm_libs_version() + + return { + "available": rocm_path is not None, + "active": is_rocm_active(), + "binary_path": str(rocm_path) if rocm_path else None, + "rocm_libs_version": rocm_libs_version, + "downloading": progress is not None and progress.get("status") == "downloading", + "download_progress": progress, + } + + +def _needs_server_download(version: Optional[str] = None) -> bool: + """Check if the server core archive needs to be (re)downloaded.""" + rocm_path = get_rocm_binary_path() + if not rocm_path: + return True + # Check if the binary version matches the expected app version + installed = get_rocm_binary_version() + expected = version or __version__ + if expected.startswith("v"): + expected = expected[1:] + return installed != expected + + +def _needs_rocm_libs_download() -> bool: + """Check if the ROCm libs archive needs to be (re)downloaded.""" + installed = get_installed_rocm_libs_version() + if installed is None: + return True + return installed != ROCM_LIBS_VERSION + + +async def _download_and_extract_archive( + client, + url: str, + sha256_url: Optional[str], + dest_dir: Path, + label: str, + progress_offset: int, + total_size: int, +): + """Download a .tar.gz archive and extract it into dest_dir. + + Args: + client: httpx.AsyncClient + url: URL of the .tar.gz archive + sha256_url: URL of the .sha256 checksum file (optional) + dest_dir: Directory to extract into + label: Human-readable label for progress updates + progress_offset: Byte offset for progress reporting (when downloading + multiple archives sequentially) + total_size: Total bytes across all downloads (for progress bar) + """ + progress = get_progress_manager() + temp_path = dest_dir / f".download-{label.replace(' ', '-')}.tmp" + + # Clean up leftover partial download + if temp_path.exists(): + temp_path.unlink() + + # Fetch expected checksum (fail-fast: never extract an unverified archive) + expected_sha = None + if sha256_url: + try: + sha_resp = await client.get(sha256_url) + sha_resp.raise_for_status() + expected_sha = sha_resp.text.strip().split()[0] + logger.info(f"{label}: expected SHA-256: {expected_sha[:16]}...") + except Exception as e: + raise RuntimeError(f"{label}: failed to fetch checksum from {sha256_url}") from e + + # Stream download, verify, and extract — always clean up temp file + downloaded = 0 + try: + async with client.stream("GET", url) as response: + response.raise_for_status() + with open(temp_path, "wb") as f: + async for chunk in response.aiter_bytes(chunk_size=1024 * 1024): + f.write(chunk) + downloaded += len(chunk) + progress.update_progress( + PROGRESS_KEY, + current=progress_offset + downloaded, + total=total_size, + filename=f"Downloading {label}", + status="downloading", + ) + + # Verify integrity + if expected_sha: + progress.update_progress( + PROGRESS_KEY, + current=progress_offset + downloaded, + total=total_size, + filename=f"Verifying {label}...", + status="downloading", + ) + sha256 = hashlib.sha256() + with open(temp_path, "rb") as f: + while True: + data = f.read(1024 * 1024) + if not data: + break + sha256.update(data) + actual = sha256.hexdigest() + if actual != expected_sha: + raise ValueError( + f"{label} integrity check failed: expected 
{expected_sha[:16]}..., got {actual[:16]}..." + ) + logger.info(f"{label}: integrity verified") + + # Extract (use data filter for path traversal protection on Python 3.12+) + progress.update_progress( + PROGRESS_KEY, + current=progress_offset + downloaded, + total=total_size, + filename=f"Extracting {label}...", + status="downloading", + ) + with tarfile.open(temp_path, "r:gz") as tar: + tar.extractall(path=dest_dir, filter="data") + + logger.info(f"{label}: extracted to {dest_dir}") + finally: + if temp_path.exists(): + temp_path.unlink() + return downloaded + + +async def download_rocm_binary(version: Optional[str] = None): + """Download the ROCm backend (server core + ROCm libs if needed). + + Downloads both archives from GitHub Releases, extracts them into + {data_dir}/backends/rocm/, and writes the rocm-libs.json manifest. + + Only downloads what's needed: + - Server core: always redownloaded (versioned with app) + - ROCm libs: only if missing or version mismatch + + Args: + version: Version tag (e.g. "v0.3.0"). Defaults to current app version. + """ + if _download_lock.locked(): + logger.info("ROCm download already in progress, skipping duplicate request") + return + async with _download_lock: + await _download_rocm_binary_locked(version) + + +async def _download_rocm_binary_locked(version: Optional[str] = None): + """Inner implementation of download_rocm_binary, called under _download_lock.""" + import httpx + + if version is None: + version = f"v{__version__}" + + progress = get_progress_manager() + rocm_dir = get_rocm_dir() + + need_server = _needs_server_download(version) + need_libs = _needs_rocm_libs_download() + + if not need_server and not need_libs: + logger.info("ROCm backend is up to date, nothing to download") + return + + logger.info( + f"Starting ROCm backend download for {version} " + f"(server={'yes' if need_server else 'cached'}, " + f"libs={'yes' if need_libs else 'cached'})" + ) + progress.update_progress( + PROGRESS_KEY, + current=0, + total=0, + filename="Preparing download...", + status="downloading", + ) + + server_base_url = f"{GITHUB_RELEASES_URL}/{version}" + libs_base_url = f"{GITHUB_RELEASES_URL}/{ROCM_LIBS_VERSION}" + server_archive = "voicebox-server-rocm.tar.gz" + libs_archive = f"rocm-libs-{ROCM_LIBS_VERSION}.tar.gz" + + # Always stage when any download is needed, then atomically rename over + # rocm_dir on success. This prevents a failed mid-extraction from leaving + # rocm_dir in a partially-installed state that still passes the + # get_rocm_binary_path() existence check. Existing files are pre-copied + # into staging so partial updates (e.g. libs-only or server-only) preserve + # whatever isn't being re-downloaded. + use_staging = need_server or need_libs + staging_dir = get_backends_dir() / "rocm-staging" + + if use_staging: + if staging_dir.exists(): + shutil.rmtree(staging_dir) + staging_dir.mkdir(parents=True, exist_ok=True) + # Preserve existing files (server or libs) that don't need re-downloading. + # Extracted archives will overwrite only what we actually download. 
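+        # Illustrative flow of the swap (directory names match the code below):
+        #   backends/rocm-staging/   <- copy of backends/rocm/ plus new archives
+        #   backends/rocm/           -> renamed to backends/rocm-backup/
+        #   backends/rocm-staging/   -> renamed to backends/rocm/
+        #   backends/rocm-backup/    deleted on success, restored on failure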
+ if rocm_dir.exists(): + shutil.copytree(rocm_dir, staging_dir, dirs_exist_ok=True) + extract_dir = staging_dir + else: + extract_dir = rocm_dir + + try: + async with httpx.AsyncClient(follow_redirects=True, timeout=30.0) as client: + # Estimate total download size + total_size = 0 + if need_server: + try: + head = await client.head(f"{server_base_url}/{server_archive}") + total_size += int(head.headers.get("content-length", 0)) + except Exception: + pass + if need_libs: + try: + head = await client.head(f"{libs_base_url}/{libs_archive}") + total_size += int(head.headers.get("content-length", 0)) + except Exception: + pass + + logger.info(f"Total download size: {total_size / 1024 / 1024:.1f} MB") + + offset = 0 + + # Download server core + if need_server: + server_downloaded = await _download_and_extract_archive( + client, + url=f"{server_base_url}/{server_archive}", + sha256_url=f"{server_base_url}/{server_archive}.sha256", + dest_dir=extract_dir, + label="ROCm server", + progress_offset=offset, + total_size=total_size, + ) + offset += server_downloaded + + # Make executable on Unix + exe_path = extract_dir / get_rocm_exe_name() + if sys.platform != "win32" and exe_path.exists(): + exe_path.chmod(0o755) + + # Download ROCm libs + if need_libs: + await _download_and_extract_archive( + client, + url=f"{libs_base_url}/{libs_archive}", + sha256_url=f"{libs_base_url}/{libs_archive}.sha256", + dest_dir=extract_dir, + label="ROCm libraries", + progress_offset=offset, + total_size=total_size, + ) + + # Write local rocm-libs.json manifest + manifest = {"version": ROCM_LIBS_VERSION} + (extract_dir / "rocm-libs.json").write_text(json.dumps(manifest, indent=2) + "\n") + + # Atomic swap: replace rocm_dir with the fully-extracted staging dir + if use_staging: + backup_dir = get_backends_dir() / "rocm-backup" + if backup_dir.exists(): + shutil.rmtree(backup_dir) + if rocm_dir.exists(): + rocm_dir.rename(backup_dir) + try: + staging_dir.rename(rocm_dir) + except Exception: + if backup_dir.exists() and not rocm_dir.exists(): + backup_dir.rename(rocm_dir) + raise + else: + if backup_dir.exists(): + shutil.rmtree(backup_dir) + + logger.info(f"ROCm backend ready at {rocm_dir}") + progress.mark_complete(PROGRESS_KEY) + + except Exception as e: + if use_staging and staging_dir.exists(): + shutil.rmtree(staging_dir) + logger.error(f"ROCm backend download failed: {e}") + progress.mark_error(PROGRESS_KEY, str(e)) + raise + + +def get_rocm_binary_version() -> Optional[str]: + """Get the version of the installed ROCm binary, or None if not installed.""" + import subprocess + + rocm_path = get_rocm_binary_path() + if not rocm_path: + return None + try: + result = subprocess.run( + [str(rocm_path), "--version"], + capture_output=True, + text=True, + timeout=30, + cwd=str(rocm_path.parent), # Run from the onedir directory + ) + # Output format: "voicebox-server 0.3.0" + for line in result.stdout.strip().splitlines(): + if "voicebox-server" in line: + return line.split()[-1] + except Exception as e: + logger.warning(f"Could not get ROCm binary version: {e}") + return None + + +async def check_and_update_rocm_binary(): + """Check if the ROCm binary is outdated and auto-download if so. + + Called on server startup. Checks both server version and ROCm libs + version. Downloads only what's needed. 
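+
+    Safe to call unconditionally: it returns early when no ROCm backend is
+    installed, and it skips the update while the ROCm backend is the running
+    process, since replacing an active onedir install would fail on Windows.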
+ """ + rocm_path = get_rocm_binary_path() + if not rocm_path: + return # No ROCm binary installed, nothing to update + + if is_rocm_active(): + logger.info("ROCm backend is active; skipping auto-update to avoid replacing the running backend") + return + + need_server = _needs_server_download() + need_libs = _needs_rocm_libs_download() + + if not need_server and not need_libs: + logger.info(f"ROCm binary is up to date (server=v{__version__}, libs={get_installed_rocm_libs_version()})") + return + + reasons = [] + if need_server: + rocm_version = get_rocm_binary_version() + reasons.append(f"server v{rocm_version} != v{__version__}") + if need_libs: + installed_libs = get_installed_rocm_libs_version() + reasons.append(f"libs {installed_libs} != {ROCM_LIBS_VERSION}") + + logger.info(f"ROCm backend needs update ({', '.join(reasons)}). Auto-downloading...") + + try: + await download_rocm_binary() + except Exception as e: + logger.error(f"Auto-update of ROCm binary failed: {e}") + + +async def delete_rocm_binary() -> bool: + """Delete the downloaded ROCm backend directory. Returns True if deleted.""" + import shutil + + rocm_dir = get_rocm_dir() + if rocm_dir.exists() and any(rocm_dir.iterdir()): + shutil.rmtree(rocm_dir) + logger.info(f"Deleted ROCm backend directory: {rocm_dir}") + return True + return False diff --git a/backend/tests/test_amd_gpu_detect.py b/backend/tests/test_amd_gpu_detect.py new file mode 100644 index 00000000..95d82157 --- /dev/null +++ b/backend/tests/test_amd_gpu_detect.py @@ -0,0 +1,87 @@ +""" +Phase 2.1 Test: AMD GPU detection on Windows. + +Validates is_amd_gpu_windows() via mocked WMI and torch queries. + +Usage: + python -m pytest backend/tests/test_amd_gpu_detect.py -v +""" + +from unittest.mock import MagicMock, patch + +from backend.utils.platform_detect import is_amd_gpu_windows + + +class TestAmdGpuWindows: + """Unit tests for is_amd_gpu_windows with mocks.""" + + @patch("backend.utils.platform_detect.platform.system", return_value="Linux") + def test_returns_false_on_linux(self, _mock_system): + """Non-Windows platforms should always return False.""" + assert is_amd_gpu_windows() is False + + @patch("backend.utils.platform_detect.platform.system", return_value="Windows") + @patch( + "backend.utils.platform_detect.subprocess.run", + return_value=MagicMock(stdout="1\n", returncode=0), + ) + def test_detects_amd_via_wmi(self, _mock_run, _mock_system): + """WMI reporting an AMD adapter should return True.""" + assert is_amd_gpu_windows() is True + + @patch("backend.utils.platform_detect.platform.system", return_value="Windows") + @patch( + "backend.utils.platform_detect.subprocess.run", + return_value=MagicMock(stdout="0\n", returncode=0), + ) + def test_no_amd_via_wmi(self, _mock_run, _mock_system): + """WMI reporting zero AMD adapters should return False.""" + assert is_amd_gpu_windows() is False + + @patch("backend.utils.platform_detect.platform.system", return_value="Windows") + @patch( + "backend.utils.platform_detect.subprocess.run", + side_effect=Exception("WMI not available"), + ) + @patch("torch.cuda.is_available", return_value=True) + @patch( + "torch.cuda.get_device_name", + return_value="AMD Radeon RX 7800 XT", + ) + def test_fallback_to_torch_radeon(self, _mock_name, _mock_avail, _mock_run, _mock_system): + """When WMI fails, torch.cuda.get_device_name('Radeon') should return True.""" + assert is_amd_gpu_windows() is True + + @patch("backend.utils.platform_detect.platform.system", return_value="Windows") + @patch( + 
"backend.utils.platform_detect.subprocess.run", + side_effect=Exception("WMI not available"), + ) + @patch("torch.cuda.is_available", return_value=True) + @patch( + "torch.cuda.get_device_name", + return_value="NVIDIA GeForce RTX 4090", + ) + def test_fallback_to_torch_nvidia(self, _mock_name, _mock_avail, _mock_run, _mock_system): + """When WMI fails, torch.cuda.get_device_name('NVIDIA') should return False.""" + assert is_amd_gpu_windows() is False + + @patch("backend.utils.platform_detect.platform.system", return_value="Windows") + @patch( + "backend.utils.platform_detect.subprocess.run", + side_effect=Exception("WMI not available"), + ) + @patch("torch.cuda.is_available", return_value=False) + def test_no_torch_cuda(self, _mock_avail, _mock_run, _mock_system): + """When WMI fails and torch.cuda is unavailable, should return False.""" + assert is_amd_gpu_windows() is False + + @patch("backend.utils.platform_detect.platform.system", return_value="Windows") + @patch( + "backend.utils.platform_detect.subprocess.run", + side_effect=Exception("WMI not available"), + ) + def test_torch_not_installed(self, _mock_run, _mock_system): + """When torch is not installed, should return False without crashing.""" + with patch.dict("sys.modules", {"torch": None}): + assert is_amd_gpu_windows() is False diff --git a/backend/tests/test_rocm_backends.py b/backend/tests/test_rocm_backends.py new file mode 100644 index 00000000..ee69de66 --- /dev/null +++ b/backend/tests/test_rocm_backends.py @@ -0,0 +1,68 @@ +""" +Phase 2.2 Test: Backend ROCm compatibility. + +Validates that check_cuda_compatibility() and other backend utilities +behave correctly on ROCm/AMD hardware. + +Usage: + python -m pytest backend/tests/test_rocm_backends.py -v +""" + +from unittest.mock import patch + +import pytest + + +class TestCheckCudaCompatibility: + """Unit tests for check_cuda_compatibility with ROCm awareness.""" + + def test_no_gpu_returns_compatible(self): + from backend.backends.base import check_cuda_compatibility + + with patch("torch.cuda.is_available", return_value=False): + compatible, warning = check_cuda_compatibility() + assert compatible is True + assert warning is None + + def test_rocm_skips_compute_check(self): + """On ROCm, the NVIDIA compute-capability check should be skipped.""" + from backend.backends.base import check_cuda_compatibility + + with patch("torch.cuda.is_available", return_value=True): + with patch("torch.version.hip", "6.2.41133"): + compatible, warning = check_cuda_compatibility() + assert compatible is True + assert warning is None + + def test_cuda_compatible_arch(self): + from backend.backends.base import check_cuda_compatibility + + with patch("torch.cuda.is_available", return_value=True): + with patch("torch.version.hip", None): + with patch("torch.cuda.get_device_capability", return_value=(8, 6)): + with patch("torch.cuda.get_device_name", return_value="NVIDIA GeForce RTX 3060"): + with patch.object( + __import__("torch").cuda, "_get_arch_list", + return_value=["sm_80", "sm_86", "sm_89"], + create=True, + ): + compatible, warning = check_cuda_compatibility() + assert compatible is True + assert warning is None + + def test_cuda_incompatible_arch(self): + from backend.backends.base import check_cuda_compatibility + + with patch("torch.cuda.is_available", return_value=True): + with patch("torch.version.hip", None): + with patch("torch.cuda.get_device_capability", return_value=(9, 0)): + with patch("torch.cuda.get_device_name", return_value="NVIDIA GeForce RTX 4090"): + with 
patch.object( + __import__("torch").cuda, "_get_arch_list", + return_value=["sm_80", "sm_86"], + create=True, + ): + compatible, warning = check_cuda_compatibility() + assert compatible is False + assert warning is not None + assert "not supported" in warning diff --git a/backend/tests/test_rocm_build.py b/backend/tests/test_rocm_build.py new file mode 100644 index 00000000..2e8d4e0c --- /dev/null +++ b/backend/tests/test_rocm_build.py @@ -0,0 +1,129 @@ +""" +Phase 1.2 Test: ROCm build script configuration. + +Validates that build_binary.py --rocm generates the correct PyInstaller +arguments and optionally performs a true E2E build. + +Usage: + python -m pytest backend/tests/test_rocm_build.py -v + python -m pytest backend/tests/test_rocm_build.py -v -m "slow" # include E2E +""" + +import subprocess +import sys +from pathlib import Path +from unittest.mock import patch + +import pytest + +from build_binary import build_server + + +class TestRocmBuildArgs: + """Validate PyInstaller arguments for ROCm builds.""" + + @pytest.fixture + def captured_args(self): + """Run build_server(rocm=True) with mocked PyInstaller and return args.""" + with ( + patch("build_binary.PyInstaller.__main__.run") as mock_run, + patch("build_binary.platform.system", return_value="Linux"), + patch("build_binary.os.chdir"), + ): + build_server(rocm=True) + return mock_run.call_args[0][0] + + def test_binary_name(self, captured_args): + idx = captured_args.index("--name") + assert captured_args[idx + 1] == "voicebox-server-rocm" + + def test_pack_mode_is_onedir(self, captured_args): + assert "--onedir" in captured_args + assert "--onefile" not in captured_args + + def test_hidden_imports_cuda(self, captured_args): + """ROCm builds must include torch.cuda hidden imports.""" + assert "torch.cuda" in captured_args + + def test_no_cudnn_hidden_import_for_rocm(self, captured_args): + """ROCm builds must NOT include NVIDIA-specific cudnn hidden imports.""" + assert "torch.backends.cudnn" not in captured_args + + def test_nvidia_excludes_present(self, captured_args): + """ROCm builds must exclude nvidia packages to avoid bundling ~3GB of bloat.""" + excludes = [] + for i, arg in enumerate(captured_args): + if arg == "--exclude-module": + excludes.append(captured_args[i + 1]) + assert "nvidia" in excludes + assert "nvidia.cudnn" in excludes + + +class TestRocmBuildCli: + """Validate CLI argument parsing for --rocm.""" + + def test_rocm_flag_parses(self): + build_script = Path(__file__).parent.parent / "build_binary.py" + result = subprocess.run( + [sys.executable, str(build_script), "--rocm", "--help"], + capture_output=True, + text=True, + ) + assert result.returncode == 0 + assert "--rocm" in result.stdout + + def test_cannot_combine_cuda_and_rocm(self): + """Building with both CUDA and ROCm should raise ValueError.""" + with pytest.raises(ValueError, match="Cannot build with both CUDA and ROCm"): + build_server(cuda=True, rocm=True) + + +@pytest.mark.slow() +@pytest.mark.skipif(sys.platform != "win32", reason="ROCm build E2E only runs on Windows") +class TestRocmBuildE2E: + """ + True end-to-end build test. + Executes build_binary.py --rocm, verifies the binary exists, and runs it + with --help to confirm it boots without import errors. 
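+
+    Invoke explicitly (the slow marker selector comes from the module docstring):
+        python -m pytest backend/tests/test_rocm_build.py -v -m slow
+    The build subprocess below is allowed up to 15 minutes (timeout=900).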
+ """ + + def test_rocm_binary_compiles_and_runs(self, tmp_path): + backend_dir = Path(__file__).parent.parent + build_script = backend_dir / "build_binary.py" + dist_dir = backend_dir / "dist" + binary_dir = dist_dir / "voicebox-server-rocm" + binary_exe = binary_dir / "voicebox-server-rocm.exe" + + # Clean previous dist if it exists to ensure a fresh build + if binary_dir.exists(): + import shutil + shutil.rmtree(binary_dir) + + # Run the full build (this can take several minutes) + result = subprocess.run( + [sys.executable, str(build_script), "--rocm"], + capture_output=True, + text=True, + cwd=str(backend_dir), + timeout=900, + ) + + assert result.returncode == 0, ( + f"Build failed with stdout:\n{result.stdout}\nstderr:\n{result.stderr}" + ) + assert binary_exe.exists(), ( + f"Expected binary not found at {binary_exe}" + ) + + # Run the binary with --help to ensure it boots without import errors + run_result = subprocess.run( + [str(binary_exe), "--help"], + capture_output=True, + text=True, + timeout=60, + ) + + # A frozen binary may not have argparse help, but it should not crash + # with a ModuleNotFoundError or similar import error. + assert "ModuleNotFoundError" not in run_result.stderr + assert "ImportError" not in run_result.stderr diff --git a/backend/tests/test_rocm_download.py b/backend/tests/test_rocm_download.py new file mode 100644 index 00000000..174ed1a9 --- /dev/null +++ b/backend/tests/test_rocm_download.py @@ -0,0 +1,203 @@ +""" +Tests for the ROCm backend download service. + +Mocks httpx to verify download, extraction, and progress reporting +without hitting the network. +""" + +import json +import tarfile +import tempfile +from io import BytesIO +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.services import rocm +from backend.utils.progress import get_progress_manager + + +@pytest.fixture(autouse=True) +def reset_progress_manager(): + """Reset the global progress manager before each test.""" + import backend.utils.progress + backend.utils.progress._progress_manager = None + yield + backend.utils.progress._progress_manager = None + + +@pytest.fixture +def mock_backends_dir(tmp_path: Path, monkeypatch): + """Patch get_data_dir so downloads land in a temp directory.""" + monkeypatch.setattr(rocm, "get_backends_dir", lambda: tmp_path / "backends") + return tmp_path / "backends" + + +@pytest.fixture +def fake_tar_gz(): + """Create an in-memory .tar.gz archive containing a dummy file.""" + buf = BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + data = b"fake binary content" + info = tarfile.TarInfo(name="voicebox-server-rocm.exe") + info.size = len(data) + tar.addfile(info, BytesIO(data)) + buf.seek(0) + return buf.read() + + +@pytest.fixture +def fake_sha256(): + """Return a dummy SHA-256 hex string.""" + return "a" * 64 + + +class FakeResponse: + """Minimal fake for httpx.Response.""" + + def __init__(self, content: bytes = b"", status_code: int = 200, headers: dict | None = None): + self.content = content + self.status_code = status_code + self.headers = headers or {} + + def raise_for_status(self): + if self.status_code >= 400: + raise Exception(f"HTTP {self.status_code}") + + def iter_bytes(self, chunk_size: int = 1024): + for i in range(0, len(self.content), chunk_size): + yield self.content[i : i + chunk_size] + + async def aiter_bytes(self, chunk_size: int = 1024): + for i in range(0, len(self.content), chunk_size): + yield self.content[i : i + chunk_size] + + @property + def 
text(self):
+        return self.content.decode()
+
+
+class FakeHttpxClient:
+    """Minimal fake for httpx.AsyncClient."""
+
+    def __init__(self, responses: dict[str, FakeResponse]):
+        self._responses = responses
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, *args):
+        return False
+
+    async def head(self, url: str):
+        return self._responses.get(url, FakeResponse(status_code=404))
+
+    async def get(self, url: str):
+        return self._responses.get(url, FakeResponse(status_code=404))
+
+    def stream(self, method: str, url: str):
+        resp = self._responses.get(url, FakeResponse(status_code=404))
+        resp.raise_for_status()
+
+        class _Streamer:
+            async def __aenter__(self):
+                return resp
+
+            async def __aexit__(self, *args):
+                return False
+
+            async def aiter_bytes(self, chunk_size: int = 1024):
+                for i in range(0, len(resp.content), chunk_size):
+                    yield resp.content[i : i + chunk_size]
+
+        return _Streamer()
+
+
+@pytest.mark.asyncio
+async def test_get_rocm_status_not_installed(mock_backends_dir):
+    status = rocm.get_rocm_status()
+    assert status["available"] is False
+    assert status["active"] is False
+    assert status["binary_path"] is None
+    assert status["downloading"] is False
+
+
+@pytest.mark.asyncio
+async def test_download_rocm_binary_progress_reporting(mock_backends_dir, fake_tar_gz):
+    """
+    Verify that download_rocm_binary():
+    1. Downloads the server archive and ROCm libs archive.
+    2. Extracts them into the backends/rocm directory.
+    3. Reports progress via the progress_manager.
+    """
+    import hashlib
+
+    server_sha = hashlib.sha256(fake_tar_gz).hexdigest()
+    libs_sha = hashlib.sha256(fake_tar_gz).hexdigest()
+
+    # The server archive lives under the app release tag; the ROCm libs
+    # archive lives under its own release tag (ROCM_LIBS_VERSION), matching
+    # the URL scheme in services/rocm.py.
+    responses = {
+        "https://github.com/jamiepine/voicebox/releases/download/v0.2.3/voicebox-server-rocm.tar.gz": FakeResponse(
+            content=fake_tar_gz,
+            headers={"content-length": str(len(fake_tar_gz))},
+        ),
+        "https://github.com/jamiepine/voicebox/releases/download/v0.2.3/voicebox-server-rocm.tar.gz.sha256": FakeResponse(
+            content=f"{server_sha} voicebox-server-rocm.tar.gz\n".encode(),
+        ),
+        f"https://github.com/jamiepine/voicebox/releases/download/{rocm.ROCM_LIBS_VERSION}/rocm-libs-{rocm.ROCM_LIBS_VERSION}.tar.gz": FakeResponse(
+            content=fake_tar_gz,
+            headers={"content-length": str(len(fake_tar_gz))},
+        ),
+        f"https://github.com/jamiepine/voicebox/releases/download/{rocm.ROCM_LIBS_VERSION}/rocm-libs-{rocm.ROCM_LIBS_VERSION}.tar.gz.sha256": FakeResponse(
+            content=f"{libs_sha} rocm-libs-{rocm.ROCM_LIBS_VERSION}.tar.gz\n".encode(),
+        ),
+    }
+
+    fake_client = FakeHttpxClient(responses)
+
+    with patch("httpx.AsyncClient", return_value=fake_client):
+        await rocm.download_rocm_binary(version="v0.2.3")
+
+    # Verify extraction
+    rocm_dir = rocm.get_rocm_dir()
+    assert (rocm_dir / "voicebox-server-rocm.exe").exists()
+
+    # Verify manifest written
+    manifest_path = rocm.get_rocm_libs_manifest_path()
+    assert manifest_path.exists()
+    data = json.loads(manifest_path.read_text())
+    assert data["version"] == rocm.ROCM_LIBS_VERSION
+
+    # Verify progress was reported
+    progress = get_progress_manager().get_progress("rocm-backend")
+    assert progress is not None
+    assert progress["status"] == "complete"
+    assert progress["progress"] == 100.0
+
+
+@pytest.mark.asyncio
+async def test_is_rocm_active(mock_backends_dir, monkeypatch):
+    monkeypatch.setenv("VOICEBOX_BACKEND_VARIANT", "rocm")
+    assert rocm.is_rocm_active() is True
+
+    monkeypatch.setenv("VOICEBOX_BACKEND_VARIANT", "cpu")
+    assert rocm.is_rocm_active() is False
+
+    monkeypatch.delenv("VOICEBOX_BACKEND_VARIANT", raising=False)
+    assert rocm.is_rocm_active() is
False + + +@pytest.mark.asyncio +async def test_delete_rocm_binary(mock_backends_dir, fake_tar_gz): + """Test deleting the ROCm backend directory.""" + rocm_dir = rocm.get_rocm_dir() + rocm_dir.mkdir(parents=True, exist_ok=True) + (rocm_dir / "dummy.txt").write_text("hello") + + result = await rocm.delete_rocm_binary() + assert result is True + assert not rocm_dir.exists() + + # Deleting again should return False + result = await rocm.delete_rocm_binary() + assert result is False diff --git a/backend/tests/test_rocm_requirements.py b/backend/tests/test_rocm_requirements.py new file mode 100644 index 00000000..a03572ba --- /dev/null +++ b/backend/tests/test_rocm_requirements.py @@ -0,0 +1,130 @@ +""" +Phase 1.1 Test: ROCm requirements installation. + +Validates that requirements-rocm.txt correctly installs ROCm-enabled PyTorch +and that torch.cuda.is_available() returns True on AMD hardware. + +Usage: + python -m pytest backend/tests/test_rocm_requirements.py -v +""" + +import os +import platform +import subprocess +import sys +import tempfile +from pathlib import Path + +import pytest + + +def _has_amd_hardware(): + """Check if AMD GPU hardware is present on Windows.""" + if platform.system() != "Windows": + return False + try: + result = subprocess.run( + [ + "powershell", + "-Command", + "Get-WmiObject Win32_VideoController | " + "Where-Object {$_.AdapterCompatibility -like '*AMD*'} | " + "Measure-Object | Select-Object -ExpandProperty Count", + ], + capture_output=True, + text=True, + check=True, + ) + return int(result.stdout.strip()) > 0 + except Exception: + return False + + +@pytest.fixture() +def backend_dir(): + return Path(__file__).parent.parent + + +class TestRocmRequirements: + """Validate requirements-rocm.txt content and installation.""" + + def test_requirements_file_exists(self, backend_dir): + req_file = backend_dir / "requirements-rocm.txt" + assert req_file.exists(), "requirements-rocm.txt must exist" + + def test_requirements_file_content(self, backend_dir): + import re + req_file = backend_dir / "requirements-rocm.txt" + content = req_file.read_text() + assert "rocm7.2" in content, "Must point to ROCm 7.2 extra index" + # Parse exact package names to avoid false positives from URL substrings + package_names = re.findall(r"^([A-Za-z][A-Za-z0-9_-]*)", content, re.MULTILINE) + assert "torch" in package_names, "Must include torch package" + assert "torchaudio" in package_names, "Must include torchaudio package" + assert "torchvision" in package_names, "Must include torchvision package" + + @pytest.mark.timeout(900) + @pytest.mark.skipif( + not os.environ.get("VOICEBOX_TEST_ROCM_INSTALL"), + reason="Set VOICEBOX_TEST_ROCM_INSTALL=1 to run the heavy install test", + ) + def test_rocm_torch_installs_and_detects_amd(self, backend_dir): + """ + Create a temporary venv, install requirements-rocm.txt, and verify + torch.cuda.is_available() returns True on AMD hardware. 
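+
+        Gated behind VOICEBOX_TEST_ROCM_INSTALL=1 because it creates a
+        throwaway venv and downloads the large ROCm torch wheels, e.g.:
+            VOICEBOX_TEST_ROCM_INSTALL=1 python -m pytest backend/tests/test_rocm_requirements.py -v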
+ """ + req_file = backend_dir / "requirements-rocm.txt" + has_amd = _has_amd_hardware() + + with tempfile.TemporaryDirectory() as tmpdir: + venv_dir = Path(tmpdir) / "venv" + subprocess.run( + [sys.executable, "-m", "venv", str(venv_dir)], + check=True, + ) + + if sys.platform == "win32": + venv_python = venv_dir / "Scripts" / "python.exe" + else: + venv_python = venv_dir / "bin" / "python" + + # Upgrade pip to avoid resolver issues + subprocess.run( + [str(venv_python), "-m", "pip", "install", "--upgrade", "pip"], + check=True, + ) + + # Install ROCm requirements + subprocess.run( + [str(venv_python), "-m", "pip", "install", "-r", str(req_file)], + check=True, + ) + + # Verify torch imports and cuda availability + result = subprocess.run( + [ + str(venv_python), + "-c", + "import torch; print(torch.__version__); print(torch.cuda.is_available())", + ], + capture_output=True, + text=True, + check=True, + ) + + lines = result.stdout.strip().splitlines() + assert len(lines) >= 2, f"Unexpected output: {result.stdout}" + torch_version = lines[0] + cuda_available = lines[1] == "True" + + # The honest test: on AMD hardware ROCm torch should report cuda available + if has_amd: + assert cuda_available, ( + f"AMD hardware detected but torch.cuda.is_available() returned False. " + f"torch version: {torch_version}, stderr: {result.stderr}" + ) + else: + assert not cuda_available, ( + f"No AMD hardware detected but torch.cuda.is_available() returned True. " + f"torch version: {torch_version}" + ) diff --git a/backend/utils/platform_detect.py b/backend/utils/platform_detect.py index 1ec2980a..2479137e 100644 --- a/backend/utils/platform_detect.py +++ b/backend/utils/platform_detect.py @@ -3,19 +3,67 @@ """ import platform +import subprocess from typing import Literal def is_apple_silicon() -> bool: """ Check if running on Apple Silicon (arm64 macOS). - + Returns: True if on Apple Silicon, False otherwise """ return platform.system() == "Darwin" and platform.machine() == "arm64" +def is_amd_gpu_windows() -> bool: + """ + Check if the primary GPU on Windows is an AMD Radeon card. + + Uses WMI to query Win32_VideoController, with a fallback to + torch.cuda.get_device_name(0) if WMI is unavailable. This is + useful for deciding whether the ROCm backend is appropriate. + + Returns: + True if an AMD GPU is detected on Windows, False otherwise. + """ + if platform.system() != "Windows": + return False + + # Primary method: WMI query for AMD adapters + try: + result = subprocess.run( + [ + "powershell", + "-Command", + "Get-CimInstance Win32_VideoController | " + "Where-Object {$_.AdapterCompatibility -like '*AMD*'} | " + "Measure-Object | Select-Object -ExpandProperty Count", + ], + capture_output=True, + text=True, + check=True, + ) + if int(result.stdout.strip()) > 0: + return True + except Exception: + pass + + # Fallback: torch.cuda.get_device_name(0) (works for ROCm/HIP too) + try: + import torch + + if torch.cuda.is_available(): + name = torch.cuda.get_device_name(0) + if "Radeon" in name or "AMD" in name: + return True + except Exception: + pass + + return False + + def get_backend_type() -> Literal["mlx", "pytorch"]: """ Detect the best backend for the current platform. diff --git a/tauri/src-tauri/src/main.rs b/tauri/src-tauri/src/main.rs index ca0cdf07..e0cbb058 100644 --- a/tauri/src-tauri/src/main.rs +++ b/tauri/src-tauri/src/main.rs @@ -93,6 +93,34 @@ struct ServerState { server_pid: Mutex>, keep_running_on_close: Mutex, models_dir: Mutex>, + /// Override the backend selection. 
When set to Some("cpu"), forces the CPU + /// sidecar even if GPU binaries exist on disk. This solves the Windows + /// catch-22 where an active .exe cannot be deleted. + backend_override: Mutex>, +} + +/// Run ` --version` with a 10-second timeout to avoid hanging Tauri startup. +/// Returns the last whitespace-delimited token from stdout (e.g. "0.4.4"), or None on any failure. +async fn probe_binary_version(exe: &std::path::Path, cwd: &std::path::Path) -> Option { + let mut cmd = tokio::process::Command::new(exe); + cmd.arg("--version") + .current_dir(cwd) + .kill_on_drop(true); + + match tokio::time::timeout(std::time::Duration::from_secs(10), cmd.output()).await { + Ok(Ok(output)) => { + let s = String::from_utf8_lossy(&output.stdout); + s.trim().split_whitespace().last().map(String::from) + } + Ok(Err(e)) => { + println!("Version probe failed: {}", e); + None + } + Err(_) => { + println!("Version probe timed out after 10s"); + None + } + } } #[command] @@ -253,6 +281,45 @@ async fn start_server( println!("Data directory: {:?}", data_dir); println!("Remote mode: {}", remote.unwrap_or(false)); + // Check for ROCm backend in data directory (onedir layout: backends/rocm/) + let rocm_binary = { + let rocm_dir = data_dir.join("backends").join("rocm"); + let rocm_name = if cfg!(windows) { + "voicebox-server-rocm.exe" + } else { + "voicebox-server-rocm" + }; + let exe_path = rocm_dir.join(rocm_name); + if exe_path.exists() { + println!("Found ROCm backend at {:?}", rocm_dir); + + let app_version = app.config().version.clone().unwrap_or_default(); + let binary_version = probe_binary_version(&exe_path, &rocm_dir).await; + let version_ok = if !app_version.is_empty() + && binary_version.as_deref() == Some(app_version.as_str()) + { + println!("ROCm binary version {} matches app version", app_version); + true + } else { + println!( + "ROCm binary version mismatch: binary={}, app={}. Falling back to CPU.", + binary_version.as_deref().unwrap_or(""), + app_version + ); + false + }; + + if version_ok { + Some(exe_path) + } else { + None + } + } else { + println!("No ROCm backend found"); + None + } + }; + // Check for CUDA backend in data directory (onedir layout: backends/cuda/) let cuda_binary = { let cuda_dir = data_dir.join("backends").join("cuda"); @@ -268,30 +335,19 @@ async fn start_server( // Version check: run --version from the onedir directory so // PyInstaller can find its support files for the fast --version path let app_version = app.config().version.clone().unwrap_or_default(); - let version_ok = match std::process::Command::new(&exe_path) - .arg("--version") - .current_dir(&cuda_dir) - .output() + let binary_version = probe_binary_version(&exe_path, &cuda_dir).await; + let version_ok = if !app_version.is_empty() + && binary_version.as_deref() == Some(app_version.as_str()) { - Ok(output) => { - // Output format: "voicebox-server X.Y.Z\n" - let version_str = String::from_utf8_lossy(&output.stdout); - let binary_version = version_str.trim().split_whitespace().last().unwrap_or(""); - if binary_version == app_version { - println!("CUDA binary version {} matches app version", binary_version); - true - } else { - println!( - "CUDA binary version mismatch: binary={}, app={}. Falling back to CPU.", - binary_version, app_version - ); - false - } - } - Err(e) => { - println!("Failed to check CUDA binary version: {}. 
Falling back to CPU.", e); - false - } + println!("CUDA binary version {} matches app version", app_version); + true + } else { + println!( + "CUDA binary version mismatch: binary={}, app={}. Falling back to CPU.", + binary_version.as_deref().unwrap_or(""), + app_version + ); + false }; if version_ok { @@ -358,24 +414,59 @@ async fn start_server( println!("Custom models directory: {}", dir); } + // Respect backend override (e.g., user wants CPU even though GPU binary exists) + let backend_override = state.backend_override.lock().unwrap().clone(); + + // If ROCm binary exists, launch it from the onedir directory. // If CUDA binary exists, launch it from the onedir directory. // .current_dir() is critical: PyInstaller onedir expects all DLLs and - // support files (nvidia/, _internal/, etc.) relative to the exe. - let spawn_result = if let Some(ref cuda_path) = cuda_binary { - let cuda_dir = cuda_path.parent().unwrap(); - println!("Launching CUDA backend: {:?} (cwd: {:?})", cuda_path, cuda_dir); - let mut cmd = app.shell().command(cuda_path.to_str().unwrap()); - cmd = cmd.current_dir(cuda_dir); - cmd = cmd.args(["--data-dir", &data_dir_str, "--port", &port_str, "--parent-pid", &parent_pid_str]); - if is_remote { - cmd = cmd.args(["--host", "0.0.0.0"]); + // support files relative to the exe. + let spawn_result = if backend_override.as_deref() != Some("cpu") { + let mut gpu_spawn = None; + + if let Some(ref rocm_path) = rocm_binary { + let rocm_dir = rocm_path.parent().unwrap(); + println!("Launching ROCm backend: {:?} (cwd: {:?})", rocm_path, rocm_dir); + let mut cmd = app.shell().command(rocm_path.to_str().unwrap()); + cmd = cmd.current_dir(rocm_dir); + cmd = cmd.args(["--data-dir", &data_dir_str, "--port", &port_str, "--parent-pid", &parent_pid_str]); + if is_remote { cmd = cmd.args(["--host", "0.0.0.0"]); } + if let Some(ref dir) = effective_models_dir { cmd = cmd.env("VOICEBOX_MODELS_DIR", dir); } + match cmd.spawn() { + Ok(r) => { gpu_spawn = Some(Ok(r)); } + Err(e) => { println!("ROCm spawn failed ({}), trying CUDA/CPU fallback", e); } + } } - if let Some(ref dir) = effective_models_dir { - cmd = cmd.env("VOICEBOX_MODELS_DIR", dir); + + if gpu_spawn.is_none() { + if let Some(ref cuda_path) = cuda_binary { + let cuda_dir = cuda_path.parent().unwrap(); + println!("Launching CUDA backend: {:?} (cwd: {:?})", cuda_path, cuda_dir); + let mut cmd = app.shell().command(cuda_path.to_str().unwrap()); + cmd = cmd.current_dir(cuda_dir); + cmd = cmd.args(["--data-dir", &data_dir_str, "--port", &port_str, "--parent-pid", &parent_pid_str]); + if is_remote { cmd = cmd.args(["--host", "0.0.0.0"]); } + if let Some(ref dir) = effective_models_dir { cmd = cmd.env("VOICEBOX_MODELS_DIR", dir); } + match cmd.spawn() { + Ok(r) => { gpu_spawn = Some(Ok(r)); } + Err(e) => { println!("CUDA spawn failed ({}), falling back to CPU", e); } + } + } + } + + if let Some(result) = gpu_spawn { + result + } else { + // Fall back to bundled CPU sidecar + sidecar = sidecar.args(["--data-dir", &data_dir_str, "--port", &port_str, "--parent-pid", &parent_pid_str]); + if is_remote { sidecar = sidecar.args(["--host", "0.0.0.0"]); } + if let Some(ref dir) = effective_models_dir { sidecar = sidecar.env("VOICEBOX_MODELS_DIR", dir); } + println!("Spawning bundled CPU server process..."); + sidecar.spawn() } - cmd.spawn() } else { - // Use the bundled CPU sidecar + // Override forces CPU — use bundled sidecar, GPU binary stays on disk + println!("Backend override=cpu: using bundled CPU sidecar"); sidecar = sidecar.args(["--data-dir", 
&data_dir_str, "--port", &port_str, "--parent-pid", &parent_pid_str]);
         if is_remote {
             sidecar = sidecar.args(["--host", "0.0.0.0"]);
@@ -383,7 +474,6 @@ async fn start_server(
         if let Some(ref dir) = effective_models_dir {
             sidecar = sidecar.env("VOICEBOX_MODELS_DIR", dir);
         }
-        println!("Spawning server process...");
         sidecar.spawn()
     };
@@ -655,9 +745,9 @@ async fn restart_server(
     println!("restart_server: waiting for port release...");
     tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
 
-    // Start server again (will auto-detect CUDA binary and use stored models_dir)
+    // Start server again (will auto-detect GPU binary and use stored models_dir)
     println!("restart_server: starting server...");
-    start_server(app, state, None, None).await
+    start_server(app, state.clone(), None, None).await
 }
 
 #[command]
@@ -666,6 +756,12 @@ fn set_keep_server_running(state: State<'_, ServerState>, keep_running: bool) {
     *state.keep_running_on_close.lock().unwrap() = keep_running;
 }
 
+#[command]
+fn set_backend_override(state: State<'_, ServerState>, backend: Option<String>) {
+    println!("set_backend_override called with: {:?}", backend);
+    *state.backend_override.lock().unwrap() = backend;
+}
+
 #[command]
 async fn start_system_audio_capture(
     state: State<'_, audio_capture::AudioCaptureState>,
@@ -720,6 +816,7 @@ pub fn run() {
             server_pid: Mutex::new(None),
             keep_running_on_close: Mutex::new(false),
             models_dir: Mutex::new(None),
+            backend_override: Mutex::new(None),
         })
         .manage(audio_capture::AudioCaptureState::new())
         .manage(audio_output::AudioOutputState::new())
@@ -792,6 +889,7 @@ pub fn run() {
             stop_server,
             restart_server,
             set_keep_server_running,
+            set_backend_override,
             start_system_audio_capture,
             stop_system_audio_capture,
             is_system_audio_supported,
diff --git a/tauri/src/platform/lifecycle.ts b/tauri/src/platform/lifecycle.ts
index b20da778..249f7e57 100644
--- a/tauri/src/platform/lifecycle.ts
+++ b/tauri/src/platform/lifecycle.ts
@@ -52,6 +52,15 @@ class TauriLifecycle implements PlatformLifecycle {
     }
   }
 
+  async setBackendOverride(backend?: string | null): Promise<void> {
+    try {
+      await invoke('set_backend_override', { backend: backend ?? undefined });
+    } catch (error) {
+      console.error('Failed to set backend override:', error);
+      throw error;
+    }
+  }
+
   async setupWindowCloseHandler(): Promise<void> {
     try {
       // Listen for window close request from Rust