64 changes: 62 additions & 2 deletions backend_service/helpers/gpu.py
@@ -106,6 +106,14 @@ def _snapshot_macos(self) -> dict[str, Any]:
# ------------------------------------------------------------------

def _snapshot_nvidia(self) -> dict[str, Any]:
# Try torch.cuda first — when the GPU bundle is installed it reads
# the right total VRAM via the CUDA driver without shelling out,
# and works even if ``nvidia-smi`` isn't on PATH (common on Windows
# when the user installs the driver but not the CUDA toolkit).
torch_snapshot = self._snapshot_torch_cuda()
if torch_snapshot is not None:
return torch_snapshot

try:
out = subprocess.check_output(
[
@@ -130,8 +138,60 @@ def _snapshot_nvidia(self) -> dict[str, Any]:
except (FileNotFoundError, subprocess.SubprocessError, ValueError):
pass

# Fallback: system RAM via psutil
return self._fallback_psutil()
# No GPU detected — return a None-VRAM dict rather than reporting
# system RAM as if it were VRAM. The image / video safety
# estimators downstream treat ``vram_total_gb is None`` as
# "unknown" and skip the crash warning, which is the correct
# behaviour when we genuinely don't know the card's capacity.
return self._no_gpu_detected()

def _snapshot_torch_cuda(self) -> dict[str, Any] | None:
"""Read total + used VRAM from torch.cuda when available.

Returns ``None`` if torch isn't importable, has no CUDA build, or
no CUDA device is currently visible (driver missing, GPU
passthrough disabled, etc.). The caller then falls through to
``nvidia-smi``.

Importing torch is heavy (~200ms first time) but the result is
cached one level up by ``get_device_vram_total_gb``, so the cost
is paid at most once per backend session.
"""
try:
import torch # type: ignore
except Exception:
return None
try:
if not torch.cuda.is_available():
return None
device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device)
total_bytes = int(props.total_memory)
try:
free_bytes, _ = torch.cuda.mem_get_info(device)
used_bytes = max(0, total_bytes - int(free_bytes))
except Exception:
used_bytes = 0
return {
"gpu_name": props.name,
"vram_total_gb": round(total_bytes / (1024 ** 3), 2),
"vram_used_gb": round(used_bytes / (1024 ** 3), 2),
"utilization_pct": None,
"temperature_c": None,
"power_w": None,
}
except Exception:
return None

def _no_gpu_detected(self) -> dict[str, Any]:
return {
"gpu_name": "No GPU detected",
"vram_total_gb": None,
"vram_used_gb": None,
"utilization_pct": None,
"temperature_c": None,
"power_w": None,
}

# ------------------------------------------------------------------
# Fallback
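The cache referred to in the ``_snapshot_torch_cuda`` docstring lives one level up, in the module-level ``get_device_vram_total_gb`` helper exercised by the new tests. Its body is not part of this diff; the sketch below only illustrates the caching contract those tests assert (``get_device_vram_total_gb``, ``reset_vram_total_cache`` and the module-level ``_monitor`` are real names from this PR, the rest is an assumption).

_vram_total_cache: dict[str, float | None] = {}

def get_device_vram_total_gb() -> float | None:
    # Total VRAM in GB, or None when no GPU was detected. ``_monitor`` is the
    # module-level GPUMonitor() instance that the tests patch.
    if "value" not in _vram_total_cache:
        snapshot = _monitor.snapshot()  # taken at most once per backend session
        _vram_total_cache["value"] = snapshot.get("vram_total_gb")
    return _vram_total_cache["value"]

def reset_vram_total_cache() -> None:
    # Forget the cached value, e.g. right after the GPU bundle finishes installing.
    _vram_total_cache.clear()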
11 changes: 11 additions & 0 deletions backend_service/image_runtime.py
@@ -537,6 +537,17 @@ def probe(self) -> ImageRuntimeStatus:
# find_spec answers "is it installable?" without triggering the
# import side effects. Device detection (cuda vs cpu) is deferred
# to preload/generate where we're about to import torch anyway.
#
# ``invalidate_caches`` matters when the GPU bundle install has
# finished mid-process: pip writes the new packages into the
# extras dir (already on ``sys.path`` from process start), but
# ``importlib`` keeps a per-finder cache of negative lookups, so
# the find_spec calls below would still report None even though
# the .dist-info folders are sitting on disk. Calling
# ``invalidate_caches`` first re-walks the path entries so the
# newly installed packages are picked up without a process
# restart.
importlib.invalidate_caches()
missing = [
package
for package, module_name in (
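For reference, the probe pattern described in the comment above amounts to invalidating importlib's per-finder caches before re-running ``find_spec`` over the bundle's packages. A self-contained sketch, with an illustrative package list rather than the runtime's actual one:

import importlib
import importlib.util

def missing_gpu_bundle_packages(packages: dict[str, str]) -> list[str]:
    # Re-walk sys.path entries so wheels pip just wrote into the extras dir
    # become visible without a backend process restart.
    importlib.invalidate_caches()
    return [
        package
        for package, module_name in packages.items()
        if importlib.util.find_spec(module_name) is None
    ]

# Example: missing_gpu_bundle_packages({"torch": "torch", "pillow": "PIL"})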
17 changes: 17 additions & 0 deletions backend_service/routes/setup.py
@@ -1067,6 +1067,23 @@ def _gpu_bundle_job_worker(python: str, extras_dir: Path) -> None:
state.cuda_verified = cuda_ok
state.attempts.append({"phase": "verify", "ok": cuda_ok, "output": detail[-2000:]})

# Tell the import system to re-scan ``sys.path`` so packages
# written into the extras dir during this run are visible to the
# next ``importlib.util.find_spec`` call (the image-runtime probe
# uses one). Without this, the runtime continues reporting
# "placeholder" until a backend restart even though the bundle
# is on disk. Also reset the cached VRAM total so the post-install
# capabilities snapshot reflects the freshly importable torch.
try:
importlib.invalidate_caches()
except Exception:
pass
try:
from backend_service.helpers.gpu import reset_vram_total_cache
reset_vram_total_cache()
except Exception:
pass

state.phase = "done"
state.percent = 100.0
state.done = True
170 changes: 170 additions & 0 deletions tests/test_gpu_detection.py
@@ -0,0 +1,170 @@
"""Tests for the Windows / Linux GPU detection helper.

The pre-fix path returned system RAM via ``psutil.virtual_memory().total``
when ``nvidia-smi`` wasn't on PATH — so an RTX 4090 box on Windows showed
12 GB total in the safety estimator instead of 24 GB. The new path tries
``torch.cuda`` first, falls back to ``nvidia-smi``, and only returns a
``vram_total_gb=None`` when neither answers. The frontend treats ``None``
as "unknown" and skips the spurious crash warning.
"""

from __future__ import annotations

import sys
import types
import unittest
from unittest import mock

from backend_service.helpers import gpu as gpu_module


def _fake_torch_with_cuda(total_bytes: int, free_bytes: int, name: str = "NVIDIA GeForce RTX 4090") -> types.ModuleType:
cuda = types.SimpleNamespace()
cuda.is_available = lambda: True
cuda.current_device = lambda: 0

class _Props:
def __init__(self, mem: int, gpu_name: str) -> None:
self.total_memory = mem
self.name = gpu_name

cuda.get_device_properties = lambda device: _Props(total_bytes, name)
cuda.mem_get_info = lambda device: (free_bytes, total_bytes)

fake = types.ModuleType("torch")
fake.cuda = cuda # type: ignore[attr-defined]
return fake


def _fake_torch_no_cuda() -> types.ModuleType:
cuda = types.SimpleNamespace()
cuda.is_available = lambda: False
fake = types.ModuleType("torch")
fake.cuda = cuda # type: ignore[attr-defined]
return fake


class SnapshotTorchCudaTests(unittest.TestCase):
def setUp(self) -> None:
gpu_module.reset_vram_total_cache()
self.monitor = gpu_module.GPUMonitor()
# Force the monitor onto the nvidia path even when running these
# tests on a Mac developer machine.
self.monitor._system = "Linux"

def tearDown(self) -> None:
gpu_module.reset_vram_total_cache()

def test_torch_cuda_returns_full_vram_for_rtx_4090(self) -> None:
twenty_four_gb = 24 * 1024 ** 3
free = 22 * 1024 ** 3
with mock.patch.dict(sys.modules, {"torch": _fake_torch_with_cuda(twenty_four_gb, free)}):
snapshot = self.monitor._snapshot_torch_cuda()
self.assertIsNotNone(snapshot)
assert snapshot is not None # type narrow
self.assertEqual(snapshot["gpu_name"], "NVIDIA GeForce RTX 4090")
self.assertEqual(snapshot["vram_total_gb"], 24.0)
# 24 - 22 = 2 GB used.
self.assertEqual(snapshot["vram_used_gb"], 2.0)

def test_torch_cuda_unavailable_returns_none(self) -> None:
with mock.patch.dict(sys.modules, {"torch": _fake_torch_no_cuda()}):
snapshot = self.monitor._snapshot_torch_cuda()
self.assertIsNone(snapshot)

def test_torch_not_installed_returns_none(self) -> None:
# Monkeypatch the import to raise ImportError.
original_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__

def fake_import(name, *args, **kwargs):
if name == "torch":
raise ImportError("No module named 'torch'")
return original_import(name, *args, **kwargs)

with mock.patch("builtins.__import__", side_effect=fake_import):
# Also remove any previously cached torch entry so the
# function's ``import torch`` actually invokes the patched
# ``__import__`` instead of resolving via sys.modules.
with mock.patch.dict(sys.modules, {}, clear=False):
sys.modules.pop("torch", None)
snapshot = self.monitor._snapshot_torch_cuda()
self.assertIsNone(snapshot)


class SnapshotNvidiaTests(unittest.TestCase):
def setUp(self) -> None:
gpu_module.reset_vram_total_cache()
self.monitor = gpu_module.GPUMonitor()
self.monitor._system = "Linux"

def tearDown(self) -> None:
gpu_module.reset_vram_total_cache()

def test_falls_back_to_no_gpu_when_torch_and_nvidia_smi_both_fail(self) -> None:
with mock.patch.object(self.monitor, "_snapshot_torch_cuda", return_value=None), \
mock.patch("subprocess.check_output", side_effect=FileNotFoundError):
snapshot = self.monitor._snapshot_nvidia()
self.assertEqual(snapshot["gpu_name"], "No GPU detected")
self.assertIsNone(snapshot["vram_total_gb"])
self.assertIsNone(snapshot["vram_used_gb"])

def test_does_not_fall_back_to_system_ram(self) -> None:
"""The whole point of this fix: don't lie that system RAM is VRAM."""
with mock.patch.object(self.monitor, "_snapshot_torch_cuda", return_value=None), \
mock.patch("subprocess.check_output", side_effect=FileNotFoundError):
snapshot = self.monitor._snapshot_nvidia()
self.assertNotEqual(snapshot["gpu_name"], "System RAM (no GPU detected)")

def test_torch_cuda_takes_precedence_over_nvidia_smi(self) -> None:
torch_snapshot = {
"gpu_name": "RTX 4090",
"vram_total_gb": 24.0,
"vram_used_gb": 1.0,
"utilization_pct": None,
"temperature_c": None,
"power_w": None,
}
with mock.patch.object(self.monitor, "_snapshot_torch_cuda", return_value=torch_snapshot), \
mock.patch("subprocess.check_output") as mock_subprocess:
snapshot = self.monitor._snapshot_nvidia()
self.assertEqual(snapshot["vram_total_gb"], 24.0)
mock_subprocess.assert_not_called()


class GetDeviceVramTotalGbTests(unittest.TestCase):
def setUp(self) -> None:
gpu_module.reset_vram_total_cache()

def tearDown(self) -> None:
gpu_module.reset_vram_total_cache()

def test_returns_none_when_snapshot_has_no_vram(self) -> None:
with mock.patch.object(
gpu_module._monitor,
"snapshot",
return_value={"vram_total_gb": None},
):
self.assertIsNone(gpu_module.get_device_vram_total_gb())

def test_returns_float_when_snapshot_has_vram(self) -> None:
with mock.patch.object(
gpu_module._monitor,
"snapshot",
return_value={"vram_total_gb": 24.0},
):
self.assertEqual(gpu_module.get_device_vram_total_gb(), 24.0)

def test_caches_result_for_process_lifetime(self) -> None:
with mock.patch.object(
gpu_module._monitor,
"snapshot",
return_value={"vram_total_gb": 24.0},
) as mock_snapshot:
gpu_module.get_device_vram_total_gb()
gpu_module.get_device_vram_total_gb()
gpu_module.get_device_vram_total_gb()
self.assertEqual(mock_snapshot.call_count, 1)


if __name__ == "__main__":
unittest.main()