From e0eee42e5f9ff1d97265fc359652b761e3338589 Mon Sep 17 00:00:00 2001 From: Cryptopoly <31970407+cryptopoly@users.noreply.github.com> Date: Thu, 30 Apr 2026 00:06:21 +0100 Subject: [PATCH 1/4] Memory-budgeted warm pool + library-only chat picker MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The warm-pool eviction policy was a fixed count (MAX_WARM_MODELS = 2), so loading three 17-18 GB models in succession on a 64 GB Mac kept all three resident — 17 + 17 + 18 + 21 GB llama-server = ~73 GB, swap, OS crash. Replace the count cap with a memory-aware policy: - _model_resident_bytes(info): on-disk weight size as a proxy for RAM, works for both mlx-lm (mmap) and llama.cpp (full-load). - _memory_budget_bytes(): live psutil snapshot of available RAM minus WARM_POOL_MEMORY_HEADROOM_BYTES (6 GB, mirrors spareHeadroomGb). - _evict_warm_pool(incoming_bytes=N): apply the count cap first (defensive — handles psutil-unavailable hosts), then evict oldest until pool + incoming fits under the budget. - _park_active_engine_or_unload now passes the parked model's resident bytes so eviction targets the right footprint. Chat picker also defaulted to a catalog-only model (nvidia/NVIDIA-Nemotron-3-Nano-4B-GGUF) on a fresh thread, then 500'd on Load with 'isn't downloaded on this machine'. Two fixes: - src/App.tsx threadModelOptions drops the catalog branch entirely. Discover tab is the place to pull a new model; the picker shouldn't surface entries that don't exist locally. - backend_service/state.py _default_session_model now prefers the first available text library entry over _default_chat_variant() when no model is loaded. Falls back to the catalog default only if the library is empty (true first-launch case). --- backend_service/inference.py | 75 +++++++++++++++++++++++++++++++----- backend_service/state.py | 34 ++++++++++++++++ src/App.tsx | 25 ++---------- 3 files changed, 104 insertions(+), 30 deletions(-) diff --git a/backend_service/inference.py b/backend_service/inference.py index 423c0fe..e3705b6 100644 --- a/backend_service/inference.py +++ b/backend_service/inference.py @@ -2165,7 +2165,13 @@ def stream_generate( class RuntimeController: + # Hard upper bound on the warm pool independently of memory accounting — + # if psutil isn't available we still want a sane cap. MAX_WARM_MODELS = 2 + # Reserve this much physical memory for the OS / UI / unrelated + # processes when deciding whether a new (or incoming) model fits. Mirrors + # the headroom used by ``helpers/system.py::spareHeadroomGb``. + WARM_POOL_MEMORY_HEADROOM_BYTES = 6 * 1024 * 1024 * 1024 def __init__(self) -> None: self.capabilities = get_backend_capabilities() @@ -2284,7 +2290,9 @@ def _park_active_engine_or_unload( except Exception: pass return - self._evict_warm_pool() + self._evict_warm_pool( + incoming_bytes=self._model_resident_bytes(self.loaded_model), + ) self._warm_pool[current_key] = (self.engine, self.loaded_model) def _tracked_process_pids(self) -> set[int]: @@ -2540,15 +2548,64 @@ def warm_models(self) -> list[dict[str, Any]]: result.append({**info.to_dict(), "warm": True, "active": False}) return result - def _evict_warm_pool(self) -> None: - """Remove the oldest entry from the warm pool if at capacity.""" + @staticmethod + def _model_resident_bytes(info: LoadedModelInfo) -> int: + """Best-effort estimate of RAM held by a loaded model. 
+ + For local weights we use on-disk size as a proxy — mlx-lm mmaps the + weights so RSS tracks file size closely; for llama.cpp / GGUF the + whole file ends up resident once warm. For catalog/no-path entries + we fall back to 0 (no useful estimate, treat as memory-free). + """ + return _path_size_bytes(info.path) if info.path else 0 + + def _warm_pool_resident_bytes(self) -> int: + return sum(self._model_resident_bytes(info) for _, info in self._warm_pool.values()) + + def _memory_budget_bytes(self) -> int: + """Bytes available for warm-pool weights, after OS headroom. + + Returns 0 when psutil isn't usable; callers must fall back to the + count-based MAX_WARM_MODELS cap in that case. + """ + try: + import psutil + + available = int(psutil.virtual_memory().available) + except Exception: + return 0 + return max(0, available - self.WARM_POOL_MEMORY_HEADROOM_BYTES) + + def _pop_oldest_warm_entry(self) -> None: + if not self._warm_pool: + return + oldest_key = next(iter(self._warm_pool)) + old_engine, _ = self._warm_pool.pop(oldest_key) + try: + old_engine.unload_model() + except Exception: + pass + + def _evict_warm_pool(self, *, incoming_bytes: int = 0) -> None: + """Make room for an incoming entry in the warm pool. + + First applies the count cap (MAX_WARM_MODELS) so a flapping budget + can never grow the pool unboundedly. Then, if ``psutil`` reports a + live memory budget, evicts oldest entries until the pool plus the + incoming model fits within ``available - headroom``. + + ``incoming_bytes`` is the resident-byte estimate for the model + about to enter the pool (typically the model being parked from + active to warm). Passing 0 still triggers the count cap. + """ while len(self._warm_pool) >= self.MAX_WARM_MODELS: - oldest_key = next(iter(self._warm_pool)) - old_engine, _ = self._warm_pool.pop(oldest_key) - try: - old_engine.unload_model() - except Exception: - pass + self._pop_oldest_warm_entry() + + budget = self._memory_budget_bytes() + if budget <= 0: + return + while self._warm_pool and self._warm_pool_resident_bytes() + incoming_bytes > budget: + self._pop_oldest_warm_entry() def load_model( self, diff --git a/backend_service/state.py b/backend_service/state.py index 63582f9..345eb86 100644 --- a/backend_service/state.py +++ b/backend_service/state.py @@ -853,6 +853,40 @@ def _default_session_model(self) -> dict[str, Any]: "treeBudget": model_info.treeBudget, } + # No model is currently loaded. Prefer a model the user actually has + # downloaded over a catalog default — surfacing a catalog-only entry + # (e.g. nvidia/NVIDIA-Nemotron-3-Nano-4B-GGUF) just produces a + # confusing "Failed to load … isn't downloaded on this machine" + # error when the user clicks Load. 
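
For reference, the eviction arithmetic above reduces to the following
standalone sketch (hypothetical model sizes; assumes psutil is
installed, which the backend's system helpers already use):

    import collections
    import psutil

    GIB = 1024 ** 3
    HEADROOM_BYTES = 6 * GIB   # mirrors WARM_POOL_MEMORY_HEADROOM_BYTES
    MAX_WARM_MODELS = 2

    # Hypothetical warm pool, oldest first: name -> resident-byte estimate
    # (on-disk weight size, per _model_resident_bytes).
    warm_pool = collections.OrderedDict([
        ("model-a", 17 * GIB),
        ("model-b", 18 * GIB),
    ])

    def evict_for(incoming_bytes: int) -> None:
        # Count cap first: still protects hosts where psutil is unusable.
        while len(warm_pool) >= MAX_WARM_MODELS:
            warm_pool.popitem(last=False)
        budget = max(0, psutil.virtual_memory().available - HEADROOM_BYTES)
        if budget <= 0:
            return  # no live budget; only the count cap applies
        # Evict oldest until pool + incoming fits under the budget.
        while warm_pool and sum(warm_pool.values()) + incoming_bytes > budget:
            warm_pool.popitem(last=False)

    evict_for(17 * GIB)  # parking a third ~17 GB model

On the 64 GB repro from the commit message, the budget check is what
evicts the oldest resident model instead of letting three 17-18 GB
models accumulate.
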
+ for entry in self._library(): + entry_type = entry.get("modelType") + if entry_type and entry_type != "text": + continue + if entry.get("broken"): + continue + return { + "model": entry["name"], + "modelRef": entry["name"], + "canonicalRepo": entry.get("canonicalRepo") or entry.get("repo"), + "modelSource": "library", + "modelPath": entry["path"], + "modelBackend": entry.get("backend", "auto"), + "cacheLabel": self._cache_label( + cache_strategy=str(launch_preferences["cacheStrategy"]), + bits=int(launch_preferences["cacheBits"]), + fp16_layers=int(launch_preferences["fp16Layers"]), + ), + "cacheStrategy": launch_preferences["cacheStrategy"], + "cacheBits": launch_preferences["cacheBits"], + "fp16Layers": launch_preferences["fp16Layers"], + "fusedAttention": launch_preferences["fusedAttention"], + "fitModelInMemory": launch_preferences["fitModelInMemory"], + "contextTokens": launch_preferences["contextTokens"], + "speculativeDecoding": launch_preferences.get("speculativeDecoding", False), + "dflashDraftModel": None, + "treeBudget": launch_preferences.get("treeBudget", 0), + } + default_variant = _default_chat_variant() return { "model": default_variant["name"], diff --git a/src/App.tsx b/src/App.tsx index 877691a..9aa3645 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -337,26 +337,9 @@ export default function App() { const selectedLibraryVariant = selectedLibraryRow?.matchedVariant ?? null; // ── Chat model options ───────────────────────────────────── - const catalogChatOptions: ChatModelOption[] = allFeaturedVariants - .filter((variant) => variant.launchMode === "direct") - .map((variant) => ({ - key: `catalog:${variant.id}`, - label: variant.name, - detail: `${variant.format} / ${variant.quantization}`, - group: "Catalog", - model: variant.name, - modelRef: variant.id, - canonicalRepo: variant.repo, - source: "catalog", - backend: variant.backend, - paramsB: variant.paramsB, - sizeGb: variant.sizeGb, - contextWindow: variant.contextWindow, - format: variant.format, - quantization: variant.quantization, - maxContext: variant.maxContext ?? null, - })); - + // Only list models present in the local library — catalog-only entries + // would let the user pick a model that isn't downloaded yet, which then + // 500s on Load. Discover tab is the place to pull a new model. 
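
The resulting resolution order for a fresh thread's default model,
condensed into a sketch (field names as in the library payloads above;
behavior simplified):

    from typing import Any

    def default_session_model(
        loaded: dict[str, Any] | None,
        library: list[dict[str, Any]],
        catalog_default: dict[str, Any],
    ) -> dict[str, Any]:
        # 1. A loaded model always wins.
        if loaded is not None:
            return loaded
        # 2. First non-broken local text model (unset modelType counts as text).
        for entry in library:
            if entry.get("modelType") in (None, "", "text") and not entry.get("broken"):
                return entry
        # 3. Catalog default only when the library is empty: true first launch.
        return catalog_default
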
const libraryChatOptions: ChatModelOption[] = workspace.library .filter((item) => (item.modelType === "text" || (!item.modelType)) && !item.broken) .map((item) => { @@ -383,7 +366,7 @@ export default function App() { }; }); - const threadModelOptions = [...catalogChatOptions, ...libraryChatOptions]; + const threadModelOptions = libraryChatOptions; // ── Cache labels (needed early by useChat) ────────────────── const currentCacheLabel = launchSettings.cacheStrategy === "native" From 91565e5c6897bdc008df6acd87bbcb4db6970e87 Mon Sep 17 00:00:00 2001 From: Cryptopoly <31970407+cryptopoly@users.noreply.github.com> Date: Thu, 30 Apr 2026 19:45:03 +0100 Subject: [PATCH 2/4] Update media model management --- backend_service/app.py | 30 +- backend_service/catalog/image_models.py | 112 +++++ backend_service/catalog/video_models.py | 2 + backend_service/helpers/discovery.py | 4 +- backend_service/helpers/images.py | 45 ++- backend_service/helpers/system.py | 20 +- backend_service/helpers/video.py | 224 +++++++++- backend_service/image_runtime.py | 3 +- backend_service/inference.py | 82 +++- backend_service/mlx_video_runtime.py | 90 ++++- backend_service/models/__init__.py | 1 + backend_service/plugins/__init__.py | 24 +- backend_service/routes/cache.py | 5 +- backend_service/routes/health.py | 14 +- backend_service/routes/images.py | 46 ++- backend_service/routes/video.py | 83 +++- backend_service/state.py | 26 +- backend_service/video_runtime.py | 15 +- cache_compression/__init__.py | 74 ++-- cache_compression/chaosengine.py | 18 +- cache_compression/triattention.py | 35 +- src/App.tsx | 6 +- src/api.ts | 4 +- src/components/LatestImageDiscoverCard.tsx | 172 -------- src/components/StartupProgressPanel.tsx | 10 +- src/features/images/ImageDiscoverTab.tsx | 268 ++++++++++-- src/features/images/ImageModelsTab.tsx | 372 +++++++++++++---- src/features/images/ImageStudioTab.tsx | 7 +- src/features/video/VideoDiscoverTab.tsx | 393 +++++++++++------- src/features/video/VideoModelsTab.tsx | 450 +++++++++++++++------ src/features/video/VideoStudioTab.tsx | 92 +++-- src/hooks/useVideoState.ts | 20 +- src/styles.css | 142 +++++++ src/types.ts | 9 + src/types/image.ts | 9 +- src/utils/__tests__/images.test.ts | 171 ++++++++ src/utils/__tests__/videos.test.ts | 27 ++ src/utils/discoverSort.ts | 52 +++ src/utils/images.ts | 145 ++++++- src/utils/videos.ts | 111 +++++ tests/test_backend_service.py | 57 +++ tests/test_discovery.py | 15 + tests/test_image_discover.py | 53 +++ tests/test_mlx_video.py | 45 +++ tests/test_video_routes.py | 274 ++++++++++++- 45 files changed, 3074 insertions(+), 783 deletions(-) delete mode 100644 src/components/LatestImageDiscoverCard.tsx create mode 100644 src/utils/__tests__/images.test.ts create mode 100644 tests/test_image_discover.py diff --git a/backend_service/app.py b/backend_service/app.py index e226b08..86977d7 100644 --- a/backend_service/app.py +++ b/backend_service/app.py @@ -8,25 +8,20 @@ import uuid from datetime import datetime from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from starlette.responses import JSONResponse -from backend_service.image_runtime import ( - ImageGenerationConfig, - ImageRuntimeManager, -) -from backend_service.video_runtime import ( - VideoGenerationConfig, - VideoRuntimeManager, - start_torch_warmup, -) from backend_service.models import ImageGenerationRequest, VideoGenerationRequest from backend_service.routes import 
register_routes from backend_service.state import ChaosEngineState +if TYPE_CHECKING: + from backend_service.image_runtime import ImageRuntimeManager + from backend_service.video_runtime import VideoRuntimeManager + # --------------------------------------------------------------------------- # Helper modules -- extracted from this file for maintainability. # --------------------------------------------------------------------------- @@ -121,8 +116,8 @@ # extracted signatures require them explicitly. # --------------------------------------------------------------------------- -def _build_system_snapshot() -> dict[str, Any]: - return _build_system_snapshot_impl(app_version, APP_STARTED_AT) +def _build_system_snapshot(*, capabilities: Any | None = None) -> dict[str, Any]: + return _build_system_snapshot_impl(app_version, APP_STARTED_AT, capabilities=capabilities) def _default_settings() -> dict[str, Any]: @@ -231,6 +226,7 @@ def compute_cache_preview( fp16_layers: int = 4, num_layers: int = 32, num_heads: int = 32, + num_kv_heads: int | None = None, hidden_size: int = 4096, context_tokens: int = 8192, params_b: float = 7.0, @@ -242,6 +238,7 @@ def compute_cache_preview( fp16_layers=fp16_layers, num_layers=num_layers, num_heads=num_heads, + num_kv_heads=num_kv_heads, hidden_size=hidden_size, context_tokens=context_tokens, params_b=params_b, @@ -343,6 +340,8 @@ def _generate_image_artifacts( runtime_manager: ImageRuntimeManager | None = None, ) -> tuple[list[dict[str, Any]], dict[str, Any]]: import logging + from backend_service.image_runtime import ImageGenerationConfig, ImageRuntimeManager + logger = logging.getLogger("chaosengine.images") effective_width, effective_height = ( _apply_draft_resolution(request.width, request.height) @@ -413,6 +412,8 @@ def _generate_video_artifact( HTTP error rather than a fake clip. """ import logging + from backend_service.video_runtime import VideoGenerationConfig + logger = logging.getLogger("chaosengine.video") logger.info( "Generating video: model=%s repo=%s size=%dx%d frames=%d steps=%d", @@ -489,7 +490,10 @@ def create_app( allow_methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"], allow_headers=["Accept", "Authorization", "Content-Type", "X-ChaosEngine-Token"], ) - app.state.chaosengine = state or ChaosEngineState(server_port=DEFAULT_PORT) + app.state.chaosengine = state or ChaosEngineState( + server_port=DEFAULT_PORT, + background_capability_probe=True, + ) app.state.chaosengine_api_token = _resolve_api_token(api_token) app.state.chaosengine_allowed_origins = frozenset(allowed_origins) # Bearer-token enforcement toggle. 
Reads from (in order) env override, diff --git a/backend_service/catalog/image_models.py b/backend_service/catalog/image_models.py index d890a46..de616a0 100644 --- a/backend_service/catalog/image_models.py +++ b/backend_service/catalog/image_models.py @@ -271,6 +271,118 @@ ] LATEST_IMAGE_TRACKED_SEEDS: list[dict[str, Any]] = [ + { + "repo": "baidu/ERNIE-Image", + "name": "ERNIE-Image", + "provider": "Baidu", + "styleTags": ["general", "detailed"], + "taskSupport": ["txt2img"], + "sizeGb": 0, + "recommendedResolution": "1024x1024", + "note": "Tracked current text-to-image DiT release from Baidu.", + "gated": False, + "pipelineTag": "text-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-04", + }, + { + "repo": "baidu/ERNIE-Image-Turbo", + "name": "ERNIE-Image-Turbo", + "provider": "Baidu", + "styleTags": ["general", "fast"], + "taskSupport": ["txt2img"], + "sizeGb": 0, + "recommendedResolution": "1024x1024", + "note": "Tracked faster ERNIE-Image lane for current local image generation discovery.", + "gated": False, + "pipelineTag": "text-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-04", + }, + { + "repo": "NucleusAI/Nucleus-Image", + "name": "Nucleus-Image", + "provider": "NucleusAI", + "styleTags": ["general", "detailed"], + "taskSupport": ["txt2img"], + "sizeGb": 0, + "recommendedResolution": "1024x1024", + "note": "Tracked current diffusers-compatible text-to-image release.", + "gated": False, + "pipelineTag": "text-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-04", + }, + { + "repo": "black-forest-labs/FLUX.2-dev", + "name": "FLUX.2 Dev", + "provider": "Black Forest Labs", + "styleTags": ["general", "detailed", "flux"], + "taskSupport": ["txt2img", "img2img"], + "sizeGb": 64.7, + "recommendedResolution": "1024x1024", + "note": "Tracked FLUX.2 generation-and-editing release.", + "gated": True, + "pipelineTag": "image-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-02", + }, + { + "repo": "black-forest-labs/FLUX.2-klein-9B", + "name": "FLUX.2 Klein 9B", + "provider": "Black Forest Labs", + "styleTags": ["general", "flux", "fast"], + "taskSupport": ["txt2img", "img2img"], + "sizeGb": 0, + "recommendedResolution": "1024x1024", + "note": "Tracked smaller FLUX.2 lane.", + "gated": False, + "pipelineTag": "image-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-02", + }, + { + "repo": "Tongyi-MAI/Z-Image-Turbo", + "name": "Z-Image-Turbo", + "provider": "Tongyi-MAI", + "styleTags": ["general", "fast"], + "taskSupport": ["txt2img"], + "sizeGb": 0, + "recommendedResolution": "1024x1024", + "note": "Tracked current Z-Image turbo text-to-image release.", + "gated": False, + "pipelineTag": "text-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-01", + }, + { + "repo": "Tongyi-MAI/Z-Image", + "name": "Z-Image", + "provider": "Tongyi-MAI", + "styleTags": ["general", "detailed"], + "taskSupport": ["txt2img"], + "sizeGb": 0, + "recommendedResolution": "1024x1024", + "note": "Tracked current Z-Image text-to-image release.", + "gated": False, + "pipelineTag": "text-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-01", + }, + { + "repo": "Qwen/Qwen-Image-Edit-2511", + "name": "Qwen-Image-Edit-2511", + "provider": "Qwen", + "styleTags": ["edit", "qwenimage", "general"], + "taskSupport": ["img2img"], + "sizeGb": 57.7, + "recommendedResolution": "1024x1024", + "note": "Tracked newer Qwen image editing release with improved consistency.", + "gated": False, + 
"pipelineTag": "image-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2025-12", + }, { "repo": "Qwen/Qwen-Image", "name": "Qwen-Image", diff --git a/backend_service/catalog/video_models.py b/backend_service/catalog/video_models.py index e3509bc..65b2e1e 100644 --- a/backend_service/catalog/video_models.py +++ b/backend_service/catalog/video_models.py @@ -160,6 +160,7 @@ "name": "LTX-2.3 · distilled (MLX)", "provider": "Lightricks · prince-canuma", "repo": "prince-canuma/LTX-2.3-distilled", + "textEncoderRepo": "prince-canuma/LTX-2-distilled", "link": "https://huggingface.co/prince-canuma/LTX-2.3-distilled", "runtime": "mlx-video (MLX native)", "styleTags": ["general", "fast", "motion", "mlx"], @@ -178,6 +179,7 @@ "name": "LTX-2.3 · dev (MLX)", "provider": "Lightricks · prince-canuma", "repo": "prince-canuma/LTX-2.3-dev", + "textEncoderRepo": "prince-canuma/LTX-2-distilled", "link": "https://huggingface.co/prince-canuma/LTX-2.3-dev", "runtime": "mlx-video (MLX native)", "styleTags": ["general", "quality", "motion", "mlx"], diff --git a/backend_service/helpers/discovery.py b/backend_service/helpers/discovery.py index 49519cf..f96c3b9 100644 --- a/backend_service/helpers/discovery.py +++ b/backend_service/helpers/discovery.py @@ -46,7 +46,7 @@ def _path_size_bytes(path: Path, *, seen: set[tuple[int, int]] | None = None) -> with iterator as entries: for entry in entries: try: - entry_stat = entry.stat(follow_symlinks=False) + entry_stat = entry.stat(follow_symlinks=True) except OSError: continue entry_id = (entry_stat.st_dev, entry_stat.st_ino) @@ -54,7 +54,7 @@ def _path_size_bytes(path: Path, *, seen: set[tuple[int, int]] | None = None) -> continue visited.add(entry_id) try: - is_dir = entry.is_dir(follow_symlinks=False) + is_dir = entry.is_dir(follow_symlinks=True) except OSError: is_dir = False if is_dir: diff --git a/backend_service/helpers/images.py b/backend_service/helpers/images.py index b58f064..60d4bb8 100644 --- a/backend_service/helpers/images.py +++ b/backend_service/helpers/images.py @@ -419,11 +419,15 @@ def _is_latest_image_candidate(model: dict[str, Any], curated_repos: set[str]) - lowered = model_id.lower() excluded_fragments = ( "-lora", + "_lora", + "lora-", "controlnet", "ip-adapter", + "adapter", "tensorrt", "_amdgpu", "onnx", + "embedding", "instruct-pix2pix", ) if any(fragment in lowered for fragment in excluded_fragments): @@ -431,19 +435,44 @@ def _is_latest_image_candidate(model: dict[str, Any], curated_repos: set[str]) - tags = {str(tag).lower() for tag in (model.get("tags") or [])} pipeline_tag = str(model.get("pipeline_tag") or "").lower() - allowed_orgs = { + excluded_tags = { + "lora", + "controlnet", + "adapter", + "adapters", + "textual-inversion", + "embedding", + "embeddings", + "onnx", + } + if tags & excluded_tags: + return False + + trusted_providers = { "black-forest-labs", + "baidu", "stabilityai", "qwen", "hidream-ai", "zai-org", + "tongyi-mai", + "nucleusai", "efficient-large-model", "hunyuanvideo-community", "tencent-hunyuan", "thudm", + "diffusers", } provider = model_id.split("/", 1)[0].lower() if "/" in model_id else "" - if provider and provider not in allowed_orgs: + try: + downloads = int(model.get("downloads") or 0) + except (TypeError, ValueError): + downloads = 0 + try: + likes = int(model.get("likes") or 0) + except (TypeError, ValueError): + likes = 0 + if provider and provider not in trusted_providers and downloads < 1000 and likes < 25: return False if "diffusers" not in tags: @@ -481,9 +510,9 @@ def 
_latest_image_model_payloads(library: list[dict[str, Any]], limit: int = 10) try: params = urllib.parse.urlencode({ "filter": "diffusers", - "sort": "modified", + "sort": "createdAt", "direction": "-1", - "limit": "48", + "limit": "96", "full": "true", }) url = f"https://huggingface.co/api/models?{params}" @@ -502,10 +531,16 @@ def _latest_image_model_payloads(library: list[dict[str, Any]], limit: int = 10) ] return _tracked_latest_seed_payloads(library)[:limit] - candidates: list[dict[str, Any]] = [] + accepted_models: list[dict[str, Any]] = [] for model in data: if not isinstance(model, dict) or not _is_latest_image_candidate(model, curated_repos): continue + accepted_models.append(model) + if len(accepted_models) >= max(limit * 2, limit): + break + + candidates: list[dict[str, Any]] = [] + for model in accepted_models: model_id = str(model.get("id") or "") provider = model_id.split("/", 1)[0] if "/" in model_id else "Community" tags = [str(tag) for tag in (model.get("tags") or [])] diff --git a/backend_service/helpers/system.py b/backend_service/helpers/system.py index b3a5584..fad84ce 100644 --- a/backend_service/helpers/system.py +++ b/backend_service/helpers/system.py @@ -378,9 +378,23 @@ def _list_llm_processes(limit: int = 12) -> list[dict[str, Any]]: return matches[:limit] -def _build_system_snapshot(app_version: str, app_started_at: float) -> dict[str, Any]: - from backend_service.inference import get_backend_capabilities - native = get_backend_capabilities().to_dict() +def _capabilities_payload(capabilities: Any | None = None) -> dict[str, Any]: + if capabilities is None: + from backend_service.inference import get_backend_capabilities + return get_backend_capabilities().to_dict() + to_dict = getattr(capabilities, "to_dict", None) + if callable(to_dict): + return dict(to_dict()) + return dict(capabilities) + + +def _build_system_snapshot( + app_version: str, + app_started_at: float, + *, + capabilities: Any | None = None, +) -> dict[str, Any]: + native = _capabilities_payload(capabilities) memory = psutil.virtual_memory() try: swap = psutil.swap_memory() diff --git a/backend_service/helpers/video.py b/backend_service/helpers/video.py index d0409ab..5a7b3f9 100644 --- a/backend_service/helpers/video.py +++ b/backend_service/helpers/video.py @@ -67,8 +67,17 @@ def _video_model_payloads(library: list[dict[str, Any]]) -> list[dict[str, Any]] # Merge live metadata first so curated fields (releaseDate, # familyName) still win when both exist. enriched = {**enriched, **live_metadata} - enriched["availableLocally"] = _video_repo_runtime_ready(repo) if repo else False - enriched["hasLocalData"] = enriched["availableLocally"] or _video_repo_has_any_local_data(repo) + validation_error = _video_variant_validation_error(enriched) + enriched["availableLocally"] = validation_error is None + enriched["hasLocalData"] = ( + enriched["availableLocally"] + or _video_variant_has_any_local_data(enriched) + ) + enriched["localStatusReason"] = ( + _video_variant_local_status_reason(enriched, validation_error) + if enriched["hasLocalData"] and validation_error + else None + ) enriched["familyName"] = family["name"] release_date = str(variant.get("releaseDate") or "").strip() or None enriched["releaseDate"] = release_date @@ -120,16 +129,16 @@ def _is_video_repo(repo_id: str) -> bool: def _video_repo_runtime_ready(repo_id: str) -> bool: """True if the local snapshot is complete enough to load. 
- Routes the validator by engine: mlx-video repos ship text_encoder / - tokenizer / transformer / vae folders without ``model_index.json``, - so the diffusers-shape check would always falsely fail for them. - Diffusers repos still go through ``validate_local_diffusers_snapshot``. + Routes the validator by engine: mlx-video repos ship component folders + without ``model_index.json``, so the diffusers-shape check would always + falsely fail for them. Diffusers repos still go through + ``validate_local_diffusers_snapshot``. """ snapshot_dir = _hf_repo_snapshot_dir(repo_id) if snapshot_dir is None: return False if _is_mlx_video_routed_repo(repo_id): - return _validate_mlx_video_snapshot(snapshot_dir) is None + return _validate_mlx_video_snapshot(snapshot_dir, repo_id) is None return validate_local_diffusers_snapshot(snapshot_dir, repo_id) is None @@ -156,19 +165,151 @@ def _video_repo_has_any_local_data(repo_id: str) -> bool: def _video_variant_available_locally(variant: dict[str, Any]) -> bool: + return _video_variant_validation_error(variant) is None + + +def _video_variant_has_any_local_data(variant: dict[str, Any]) -> bool: + repo = str(variant.get("repo") or "") + if repo and _video_repo_has_any_local_data(repo): + return True + gguf_repo = str(variant.get("ggufRepo") or "") + if gguf_repo and _video_repo_has_any_local_data(gguf_repo): + return True + return False + + +def _video_variant_validation_error(variant: dict[str, Any]) -> str | None: repo = str(variant.get("repo") or "") if not repo: - return False - return _video_repo_runtime_ready(repo) + return "Video model variant is missing its base repo id." + repo_error = _video_download_validation_error(repo) + if repo_error: + return repo_error + text_error = _video_variant_mlx_text_components_validation_error(variant) + if text_error: + return text_error + return _video_variant_gguf_validation_error(variant) + + +def _video_variant_mlx_text_components_validation_error(variant: dict[str, Any]) -> str | None: + repo = str(variant.get("repo") or "") + if not _is_mlx_video_routed_repo(repo): + return None + snapshot_dir = _hf_repo_snapshot_dir(repo) + if snapshot_dir is None: + return None + missing = _missing_mlx_text_components(Path(snapshot_dir)) + if not missing: + return None + + text_encoder_repo = str(variant.get("textEncoderRepo") or "").strip() + if text_encoder_repo and text_encoder_repo != repo: + text_snapshot = _hf_repo_snapshot_dir(text_encoder_repo) + if text_snapshot is not None and not _missing_mlx_text_components(Path(text_snapshot)): + return None + return ( + "The local snapshot is missing shared mlx-video text components: " + f"{', '.join(missing)}. Download the shared text encoder " + f"({text_encoder_repo}) and retry." + ) + + return ( + "The local snapshot is incomplete. Missing mlx-video components: " + f"{', '.join(missing)}. Re-download the model and keep ChaosEngineAI " + "open until the download completes." 
+ ) + + +def _video_variant_missing_text_encoder_repo(variant: dict[str, Any]) -> str | None: + error = _video_variant_mlx_text_components_validation_error(variant) + text_encoder_repo = str(variant.get("textEncoderRepo") or "").strip() + if error and text_encoder_repo: + return text_encoder_repo + return None + + +def _video_variant_local_status_reason( + variant: dict[str, Any], + validation_error: str | None, +) -> str | None: + if not validation_error: + return None + gguf_file = str(variant.get("ggufFile") or "").strip() + gguf_repo = str(variant.get("ggufRepo") or "").strip() + if gguf_file and gguf_repo and "GGUF transformer file is missing" in validation_error: + return f"Base model installed; missing GGUF transformer: {gguf_repo}/{gguf_file}." + + prefix = "The local snapshot is incomplete. Missing mlx-video components: " + if validation_error.startswith(prefix): + missing = validation_error[len(prefix):].split(". Re-download", 1)[0] + return f"Missing MLX components: {missing}." + + shared_prefix = "The local snapshot is missing shared mlx-video text components: " + if validation_error.startswith(shared_prefix): + missing = validation_error[len(shared_prefix):].split(". Download", 1)[0] + text_encoder_repo = str(variant.get("textEncoderRepo") or "").strip() + source = f" from {text_encoder_repo}" if text_encoder_repo else "" + return f"Missing shared MLX text components{source}: {missing}." + + if validation_error.startswith("The selected GGUF transformer resolved to a cache path"): + return f"GGUF transformer cache path is invalid: {gguf_repo}/{gguf_file}." + + return validation_error + + +def _video_variant_gguf_validation_error(variant: dict[str, Any]) -> str | None: + gguf_file = str(variant.get("ggufFile") or "").strip() + if not gguf_file: + return None + gguf_repo = str(variant.get("ggufRepo") or "").strip() + if not gguf_repo: + return ( + f"{variant.get('name') or 'This GGUF video variant'} is missing " + "its GGUF repository metadata." + ) + try: + from huggingface_hub import hf_hub_download # type: ignore + + local_path = hf_hub_download( + repo_id=gguf_repo, + filename=gguf_file, + local_files_only=True, + ) + except Exception: + return ( + "The base diffusers snapshot is installed, but the selected GGUF " + f"transformer file is missing: {gguf_repo}/{gguf_file}. Download " + "the GGUF variant before generating so the app does not fall back " + "to the full BF16 transformer." + ) + if not Path(local_path).exists(): + return ( + "The selected GGUF transformer resolved to a cache path that does " + f"not exist: {gguf_repo}/{gguf_file}. Retry the GGUF download." 
+ ) + return None def _video_download_repo_ids() -> set[str]: - return { + repos = { str(variant.get("repo") or "") for family in VIDEO_MODEL_FAMILIES for variant in family["variants"] if str(variant.get("repo") or "") } + repos.update( + str(variant.get("ggufRepo") or "") + for family in VIDEO_MODEL_FAMILIES + for variant in family["variants"] + if str(variant.get("ggufRepo") or "") + ) + repos.update( + str(variant.get("textEncoderRepo") or "") + for family in VIDEO_MODEL_FAMILIES + for variant in family["variants"] + if str(variant.get("textEncoderRepo") or "") + ) + return repos # Diffusers pipelines only need the standard per-component folders @@ -202,13 +343,23 @@ def _video_download_repo_ids() -> set[str]: _VIDEO_MLX_ALLOW_PATTERNS: list[str] = [ "text_encoder/**", "tokenizer/**", + "text_projections/**", + "audio_vae/**", "transformer/**", "vae/**", + "vocoder/**", "*spatial-upscaler*.safetensors", "*.md", "LICENSE*", ] +_VIDEO_MLX_TEXT_ENCODER_ALLOW_PATTERNS: list[str] = [ + "text_encoder/**", + "tokenizer/**", + "*.md", + "LICENSE*", +] + def _video_repo_allow_patterns(repo_id: str) -> list[str] | None: """Patterns to pass to ``snapshot_download`` for a video repo. @@ -236,11 +387,10 @@ def _video_download_validation_error(repo_id: str) -> str | None: "Retry the download and make sure the backend can access Hugging Face." ) # mlx-video routed repos (e.g. ``prince-canuma/LTX-2-*``) ship MLX - # layout — text_encoder / tokenizer / transformer / vae folders - # without ``model_index.json``. Don't apply the diffusers-shape - # validator to them; check for the MLX component folders instead. + # layouts without ``model_index.json``. Don't apply the diffusers-shape + # validator to them; check for the expected MLX component folders instead. if _is_mlx_video_routed_repo(repo_id): - return _validate_mlx_video_snapshot(snapshot_dir) + return _validate_mlx_video_snapshot(snapshot_dir, repo_id) return validate_local_diffusers_snapshot(snapshot_dir, repo_id) @@ -261,16 +411,54 @@ def _is_mlx_video_routed_repo(repo_id: str) -> bool: # diffusers layout — no model_index.json. Lifted from the ``prince-canuma/ # LTX-2-distilled`` repo tree as the canonical shape; bump as new mlx-video # families with different layouts come online. -_MLX_VIDEO_REQUIRED_COMPONENTS: tuple[str, ...] = ( +_MLX_VIDEO_LTX2_REQUIRED_COMPONENTS: tuple[str, ...] = ( "text_encoder", "tokenizer", + "text_projections", + "transformer", + "vae", +) + +_MLX_VIDEO_LTX23_REQUIRED_COMPONENTS: tuple[str, ...] = ( + "audio_vae", + "text_projections", "transformer", "vae", + "vocoder", ) -def _validate_mlx_video_snapshot(snapshot_dir: str) -> str | None: - """Return ``None`` if the snapshot has the four MLX component folders. 
+def _mlx_video_required_components(repo_id: str | None = None) -> tuple[str, ...]: + repo_key = str(repo_id or "").lower() + if "ltx-2.3" in repo_key: + return _MLX_VIDEO_LTX23_REQUIRED_COMPONENTS + return _MLX_VIDEO_LTX2_REQUIRED_COMPONENTS + + +def _missing_mlx_text_components(root: Path) -> list[str]: + missing: list[str] = [] + checks = { + "text_encoder": ( + root / "text_encoder" / "config.json", + root / "text_encoder" / "model.safetensors.index.json", + ), + "tokenizer": ( + root / "tokenizer" / "tokenizer.json", + root / "tokenizer" / "tokenizer.model", + ), + } + for component, required_paths in checks.items(): + component_dir = root / component + if not component_dir.is_dir(): + missing.append(component) + continue + if not all(path.exists() for path in required_paths): + missing.append(component) + return missing + + +def _validate_mlx_video_snapshot(snapshot_dir: str, repo_id: str | None = None) -> str | None: + """Return ``None`` if the snapshot has the required MLX component folders. Mirrors the contract of ``validate_local_diffusers_snapshot`` so the callers can swap one for the other without restructuring the result @@ -284,7 +472,7 @@ def _validate_mlx_video_snapshot(snapshot_dir: str) -> str | None: "Re-download the model." ) missing: list[str] = [] - for component in _MLX_VIDEO_REQUIRED_COMPONENTS: + for component in _mlx_video_required_components(repo_id): component_dir = root / component if not component_dir.is_dir(): missing.append(component) diff --git a/backend_service/image_runtime.py b/backend_service/image_runtime.py index c4f987b..5fd46ea 100644 --- a/backend_service/image_runtime.py +++ b/backend_service/image_runtime.py @@ -26,7 +26,6 @@ PHASE_LOADING, PHASE_SAVING, ) -from cache_compression import apply_diffusion_cache_strategy WORKSPACE_ROOT = Path(__file__).resolve().parents[1] @@ -660,6 +659,8 @@ def generate(self, config: ImageGenerationConfig) -> list[GeneratedImage]: # for this pipeline yet we swallow NotImplementedError and run # the stock pipeline — the UI surfaces the "Scaffold" badge so # users know why speedup didn't appear. + from cache_compression import apply_diffusion_cache_strategy + cache_note = apply_diffusion_cache_strategy( pipeline, strategy_id=config.cacheStrategy, diff --git a/backend_service/inference.py b/backend_service/inference.py index e3705b6..dd105cf 100644 --- a/backend_service/inference.py +++ b/backend_service/inference.py @@ -593,6 +593,7 @@ class BackendCapabilities: converterAvailable: bool = False vllmAvailable: bool = False vllmVersion: str | None = None + probing: bool = False def to_dict(self) -> dict[str, Any]: return { @@ -610,6 +611,7 @@ def to_dict(self) -> dict[str, Any]: "converterAvailable": self.converterAvailable, "vllmAvailable": self.vllmAvailable, "vllmVersion": self.vllmVersion, + "probing": self.probing, } @@ -617,6 +619,35 @@ def to_dict(self) -> dict[str, Any]: _capability_lock = RLock() +def _initial_backend_capabilities() -> BackendCapabilities: + """Cheap capability placeholder used while the real probe runs. + + The full probe imports/spawns MLX and checks vLLM, which can add seconds + to cold start. These path checks are safe enough for initial UI rendering; + load_model() still refreshes capabilities synchronously before selecting + an engine. 
+ """ + python_executable = _resolve_mlx_python() + llama_server_path = _resolve_llama_server() + llama_server_turbo_path = _resolve_llama_server_turbo() + llama_cli_path = _resolve_llama_cli() + return BackendCapabilities( + pythonExecutable=python_executable, + mlxAvailable=False, + mlxLmAvailable=False, + mlxUsable=False, + mlxMessage="Native backend detection is still running.", + ggufAvailable=bool(llama_server_path) or bool(llama_server_turbo_path), + llamaCliPath=llama_cli_path, + llamaServerPath=llama_server_path, + llamaServerTurboPath=llama_server_turbo_path, + converterAvailable=False, + vllmAvailable=False, + vllmVersion=None, + probing=True, + ) + + def _probe_native_backends() -> BackendCapabilities: python_executable = _resolve_mlx_python() llama_server_path = _resolve_llama_server() @@ -2173,8 +2204,8 @@ class RuntimeController: # the headroom used by ``helpers/system.py::spareHeadroomGb``. WARM_POOL_MEMORY_HEADROOM_BYTES = 6 * 1024 * 1024 * 1024 - def __init__(self) -> None: - self.capabilities = get_backend_capabilities() + def __init__(self, *, background_probe: bool = False) -> None: + self.capabilities = _initial_backend_capabilities() self.engine: BaseInferenceEngine = MockInferenceEngine(self.capabilities) self.loaded_model: LoadedModelInfo | None = None self.runtime_note: str | None = None @@ -2184,6 +2215,51 @@ def __init__(self) -> None: self._loading_progress: dict[str, Any] | None = None self._loading_log_tail: list[str] = [] self._recent_orphaned_workers: list[dict[str, Any]] = [] + self._capability_probe_thread: Thread | None = None + self._capability_probe_lock = Lock() + if background_probe: + self.start_capability_probe() + + def start_capability_probe(self, *, force: bool = False) -> None: + with self._capability_probe_lock: + if ( + self._capability_probe_thread is not None + and self._capability_probe_thread.is_alive() + and not force + ): + return + thread = Thread( + target=self._capability_probe_worker, + kwargs={"force": force}, + name="chaosengine-capability-probe", + daemon=True, + ) + self._capability_probe_thread = thread + thread.start() + + def _capability_probe_worker(self, *, force: bool = False) -> None: + try: + capabilities = get_backend_capabilities(force=force) + except Exception as exc: + current = self.capabilities + capabilities = BackendCapabilities( + pythonExecutable=current.pythonExecutable, + mlxAvailable=False, + mlxLmAvailable=False, + mlxUsable=False, + mlxMessage=f"Native backend detection failed: {type(exc).__name__}: {exc}", + ggufAvailable=current.ggufAvailable, + llamaCliPath=current.llamaCliPath, + llamaServerPath=current.llamaServerPath, + llamaServerTurboPath=current.llamaServerTurboPath, + converterAvailable=False, + vllmAvailable=False, + vllmVersion=None, + probing=False, + ) + self.capabilities = capabilities + if isinstance(self.engine, MockInferenceEngine): + self.engine.capabilities = capabilities @staticmethod def _warm_pool_key( @@ -2468,6 +2544,8 @@ def refresh_capabilities(self, *, force: bool = False) -> BackendCapabilities: _LLAMA_HELP_CACHE.clear() _CACHE_TYPE_CACHE.clear() self.capabilities = get_backend_capabilities(force=force) + if isinstance(self.engine, MockInferenceEngine): + self.engine.capabilities = self.capabilities return self.capabilities def _select_engine( diff --git a/backend_service/mlx_video_runtime.py b/backend_service/mlx_video_runtime.py index b52abfb..346d170 100644 --- a/backend_service/mlx_video_runtime.py +++ b/backend_service/mlx_video_runtime.py @@ -88,6 +88,11 @@ 
("Lightricks/LTX-2.3", "ltx-2.3-spatial-upscaler-x2-1.0.safetensors"), ), } +_LTX2_SHARED_TEXT_ENCODER_CANDIDATES: tuple[str, ...] = ( + "prince-canuma/LTX-2-distilled", + "Lightricks/LTX-2", +) +_LTX2_TEXT_COMPONENTS: tuple[str, ...] = ("text_encoder", "tokenizer") _LTX2_DISTILLED_STAGE_1_STEPS = 8 _LTX2_DISTILLED_STAGE_2_STEPS = 3 @@ -255,6 +260,86 @@ def _resolve_ltx2_spatial_upscaler( ) +def _resolve_local_snapshot(repo_or_path: str) -> Path | None: + candidate = Path(repo_or_path) + if candidate.exists(): + return candidate + try: + from huggingface_hub import snapshot_download # type: ignore + + return Path(snapshot_download(repo_id=repo_or_path, local_files_only=True)) + except Exception: + return None + + +def _missing_ltx2_text_components(root: Path) -> list[str]: + missing: list[str] = [] + checks = { + "text_encoder": ( + root / "text_encoder" / "config.json", + root / "text_encoder" / "model.safetensors.index.json", + ), + "tokenizer": ( + root / "tokenizer" / "tokenizer.json", + root / "tokenizer" / "tokenizer.model", + ), + } + for component, required_paths in checks.items(): + if not (root / component).is_dir(): + missing.append(component) + continue + if not all(path.exists() for path in required_paths): + missing.append(component) + return missing + + +def _resolve_ltx2_text_component_source(repo: str) -> Path: + for candidate_repo in tuple(dict.fromkeys((repo, *_LTX2_SHARED_TEXT_ENCODER_CANDIDATES))): + snapshot = _resolve_local_snapshot(candidate_repo) + if snapshot is not None and not _missing_ltx2_text_components(snapshot): + return snapshot + checked = ", ".join(_LTX2_SHARED_TEXT_ENCODER_CANDIDATES) + raise RuntimeError( + "LTX-2.3 MLX generation needs shared text_encoder and tokenizer " + f"components, but none were found locally. Download {checked} or " + "resume this model download, then try again." + ) + + +def _prepare_ltx2_model_path(repo: str, workspace: Path) -> Path: + model_path = _resolve_local_snapshot(repo) + if model_path is None: + raise RuntimeError( + f"LTX-2 MLX model snapshot is not available locally for {repo}. " + "Download the model before generating." + ) + + missing = _missing_ltx2_text_components(model_path) + if not missing: + return model_path + + text_source = _resolve_ltx2_text_component_source(repo) + overlay = workspace / "model-overlay" + shutil.rmtree(overlay, ignore_errors=True) + overlay.mkdir(parents=True, exist_ok=True) + + missing_set = set(missing) + for entry in model_path.iterdir(): + if entry.name in missing_set: + continue + (overlay / entry.name).symlink_to(entry, target_is_directory=entry.is_dir()) + for component in _LTX2_TEXT_COMPONENTS: + target = overlay / component + if target.exists() or target.is_symlink(): + if target.is_dir() and not target.is_symlink(): + shutil.rmtree(target) + else: + target.unlink() + source = text_source / component + target.symlink_to(source, target_is_directory=True) + return overlay + + class _ProgressSink(Protocol): def __call__(self, phase: str, message: str, fraction: float) -> None: ... 
@@ -409,12 +494,15 @@ def _build_cmd( entry = _resolve_entry_point(config.repo) python = _resolve_video_python() pipeline_flag = _resolve_pipeline_flag(config.repo) + model_repo_arg = config.repo + if resolve_aux_files and "ltx-2.3" in config.repo.lower(): + model_repo_arg = str(_prepare_ltx2_model_path(config.repo, output_path.parent)) cmd = [ python, "-m", entry, "--model-repo", - config.repo, + model_repo_arg, "--pipeline", pipeline_flag, "--prompt", diff --git a/backend_service/models/__init__.py b/backend_service/models/__init__.py index fe70463..a47fe80 100644 --- a/backend_service/models/__init__.py +++ b/backend_service/models/__init__.py @@ -195,6 +195,7 @@ class DeleteModelRequest(BaseModel): class DownloadModelRequest(BaseModel): repo: str = Field(min_length=3, max_length=256) + modelId: str | None = Field(default=None, min_length=1, max_length=256) class ImageGenerationRequest(BaseModel): diff --git a/backend_service/plugins/__init__.py b/backend_service/plugins/__init__.py index f218910..7d8f56d 100644 --- a/backend_service/plugins/__init__.py +++ b/backend_service/plugins/__init__.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from enum import Enum from pathlib import Path +from threading import RLock from typing import Any import importlib import json @@ -38,28 +39,43 @@ class BasePlugin(ABC): def manifest(self) -> PluginManifest: ... class PluginRegistry: - def __init__(self): + def __init__(self, *, auto_register_builtins: bool = False): self._plugins: dict[str, tuple[PluginManifest, Any]] = {} + self._auto_register_builtins = auto_register_builtins + self._builtins_registered = False + self._lock = RLock() def register(self, manifest: PluginManifest, instance: Any = None): self._plugins[manifest.id] = (manifest, instance) + def ensure_builtins(self) -> None: + if not self._auto_register_builtins or self._builtins_registered: + return + with self._lock: + if not self._builtins_registered: + self.register_builtins() + def get(self, plugin_id: str) -> tuple[PluginManifest, Any] | None: + self.ensure_builtins() return self._plugins.get(plugin_id) def list_all(self) -> list[PluginManifest]: + self.ensure_builtins() return [m for m, _ in self._plugins.values()] def list_by_type(self, plugin_type: PluginType) -> list[tuple[PluginManifest, Any]]: + self.ensure_builtins() return [(m, i) for m, i in self._plugins.values() if m.plugin_type == plugin_type] def enable(self, plugin_id: str) -> bool: + self.ensure_builtins() if plugin_id in self._plugins: self._plugins[plugin_id][0].enabled = True return True return False def disable(self, plugin_id: str) -> bool: + self.ensure_builtins() if plugin_id in self._plugins: self._plugins[plugin_id][0].enabled = False return True @@ -89,7 +105,7 @@ def register_builtins(self): """Register all built-in components as plugins.""" # Cache strategies from cache_compression import registry as cache_registry - for strategy in cache_registry._strategies.values(): + for strategy in cache_registry.strategies(): manifest = PluginManifest( id=f"cache.{strategy.strategy_id}", name=strategy.name, @@ -110,7 +126,7 @@ def register_builtins(self): description=tool.description, ) self.register(manifest, tool) + self._builtins_registered = True # Module singleton -plugin_registry = PluginRegistry() -plugin_registry.register_builtins() +plugin_registry = PluginRegistry(auto_register_builtins=True) diff --git a/backend_service/routes/cache.py b/backend_service/routes/cache.py index 418b20e..1edd5ed 100644 --- a/backend_service/routes/cache.py +++ 
b/backend_service/routes/cache.py @@ -2,7 +2,7 @@ from typing import Any -from fastapi import APIRouter, Query +from fastapi import APIRouter, Query, Request from backend_service.app import _build_system_snapshot, compute_cache_preview @@ -11,6 +11,7 @@ @router.get("/api/cache/preview") def cache_preview( + request: Request, bits: int = Query(3, ge=0, le=8), fp16_layers: int = Query(4, ge=0, le=16), num_layers: int = Query(32, ge=1, le=160), @@ -21,7 +22,7 @@ def cache_preview( params_b: float = Query(7.0, ge=0.5, le=1000.0), strategy: str = Query("native"), ) -> dict[str, Any]: - system_stats = _build_system_snapshot() + system_stats = _build_system_snapshot(capabilities=request.app.state.chaosengine.runtime.capabilities) return compute_cache_preview( bits=bits, fp16_layers=fp16_layers, diff --git a/backend_service/routes/health.py b/backend_service/routes/health.py index 8ddf97e..8d40109 100644 --- a/backend_service/routes/health.py +++ b/backend_service/routes/health.py @@ -15,18 +15,16 @@ def health(request: Request) -> dict[str, Any]: state = request.app.state.chaosengine from backend_service.app import WORKSPACE_ROOT, app_version - runtime_status = state.runtime.status( - active_requests=state.active_requests, - requests_served=state.requests_served, - ) + capabilities = state.runtime.capabilities + loaded_model = state.runtime.loaded_model return { "status": "ok", "workspaceRoot": str(WORKSPACE_ROOT), - "runtime": _runtime_label(), + "runtime": _runtime_label(capabilities.to_dict()), "appVersion": app_version, - "engine": runtime_status["engine"], - "loadedModel": runtime_status["loadedModel"], - "nativeBackends": runtime_status["nativeBackends"], + "engine": state.runtime.engine.engine_name, + "loadedModel": loaded_model.to_dict() if loaded_model is not None else None, + "nativeBackends": capabilities.to_dict(), } diff --git a/backend_service/routes/images.py b/backend_service/routes/images.py index 9a741c5..7f81689 100644 --- a/backend_service/routes/images.py +++ b/backend_service/routes/images.py @@ -25,11 +25,53 @@ _find_image_output, _delete_image_output, ) -from backend_service.progress import GenerationCancelled, IMAGE_PROGRESS +from backend_service.progress import GenerationCancelled, IMAGE_PROGRESS, VIDEO_PROGRESS router = APIRouter() +def _unload_idle_video_runtime_for_image(request: Request, action: str) -> None: + """Free resident video diffusion weights before image work starts. + + Image and video pipelines live in separate managers, so loading an image + model no longer implicitly releases a previously-loaded video model. That + can leave tens of GB resident across Studio switches. If video generation + is actively running, fail fast instead of blocking the image request behind + a long render. + """ + state = request.app.state.chaosengine + if VIDEO_PROGRESS.snapshot().get("active"): + raise HTTPException( + status_code=409, + detail=( + "A video generation is still running. Wait for it to finish or cancel it " + "before loading an image model." 
+ ), + ) + try: + runtime = state.video_runtime.capabilities() + except Exception: + return + loaded_repo = str(runtime.get("loadedModelRepo") or "") + if not loaded_repo: + return + try: + state.video_runtime.unload() + except Exception as exc: + state.add_log( + "images", + "warning", + f"Could not unload video model before {action}: {type(exc).__name__}: {exc}", + ) + return + state.add_log( + "images", + "info", + f"Unloaded video model {loaded_repo} before {action} to free memory.", + ) + state.add_activity("Video model unloaded", f"Freed memory for {action}") + + @router.get("/api/images/catalog") def image_catalog(request: Request) -> dict[str, Any]: state = request.app.state.chaosengine @@ -89,6 +131,7 @@ def preload_image_model(request: Request, body: ImageRuntimePreloadRequest) -> d validation_error = _image_download_validation_error(variant["repo"]) detail = validation_error or f"{variant['name']} is not installed locally yet." raise HTTPException(status_code=409, detail=detail) + _unload_idle_video_runtime_for_image(request, "image preload") try: runtime = state.image_runtime.preload(variant["repo"]) except RuntimeError as exc: @@ -185,6 +228,7 @@ def generate_image(request: Request, body: ImageGenerationRequest) -> dict[str, state.add_log("images", "error", f"Image model not found in catalog or tracked seeds: '{body.modelId}'") raise HTTPException(status_code=404, detail=f"Unknown image model '{body.modelId}'. The model isn't in the curated catalog or tracked seeds.") state.add_log("images", "info", f"Resolved variant: {variant.get('name')} (repo={variant.get('repo')})") + _unload_idle_video_runtime_for_image(request, "image generation") try: artifacts, runtime = _generate_image_artifacts(body, variant, state.image_runtime) except GenerationCancelled: diff --git a/backend_service/routes/video.py b/backend_service/routes/video.py index 54485ec..58c0398 100644 --- a/backend_service/routes/video.py +++ b/backend_service/routes/video.py @@ -13,13 +13,16 @@ from fastapi.responses import FileResponse from backend_service.helpers.video import ( + _VIDEO_MLX_TEXT_ENCODER_ALLOW_PATTERNS, _find_video_variant, _find_video_variant_by_repo, _is_video_repo, _video_download_repo_ids, _video_download_validation_error, _video_model_payloads, + _video_variant_missing_text_encoder_repo, _video_variant_available_locally, + _video_variant_validation_error, ) from backend_service.models import ( DownloadModelRequest, @@ -27,12 +30,47 @@ VideoRuntimePreloadRequest, VideoRuntimeUnloadRequest, ) -from backend_service.progress import GenerationCancelled, VIDEO_PROGRESS +from backend_service.progress import GenerationCancelled, IMAGE_PROGRESS, VIDEO_PROGRESS router = APIRouter() +def _unload_idle_image_runtime_for_video(request: Request, action: str) -> None: + """Free resident image diffusion weights before video work starts.""" + state = request.app.state.chaosengine + if IMAGE_PROGRESS.snapshot().get("active"): + raise HTTPException( + status_code=409, + detail=( + "An image generation is still running. Wait for it to finish or cancel it " + "before loading a video model." 
+ ), + ) + try: + runtime = state.image_runtime.capabilities() + except Exception: + return + loaded_repo = str(runtime.get("loadedModelRepo") or "") + if not loaded_repo: + return + try: + state.image_runtime.unload() + except Exception as exc: + state.add_log( + "video", + "warning", + f"Could not unload image model before {action}: {type(exc).__name__}: {exc}", + ) + return + state.add_log( + "video", + "info", + f"Unloaded image model {loaded_repo} before {action} to free memory.", + ) + state.add_activity("Image model unloaded", f"Freed memory for {action}") + + @router.get("/api/video/catalog") def video_catalog(request: Request) -> dict[str, Any]: """Return the curated catalog of video generation models.""" @@ -114,10 +152,11 @@ def preload_video_model(request: Request, body: VideoRuntimePreloadRequest) -> d raise HTTPException(status_code=404, detail=f"Unknown video model '{body.modelId}'.") if not _video_variant_available_locally(variant): - validation_error = _video_download_validation_error(variant["repo"]) + validation_error = _video_variant_validation_error(variant) detail = validation_error or f"{variant['name']} is not installed locally yet." raise HTTPException(status_code=409, detail=detail) + _unload_idle_image_runtime_for_video(request, "video preload") try: runtime = state.video_runtime.preload(variant["repo"]) except RuntimeError as exc: @@ -257,10 +296,11 @@ def generate_video(request: Request, body: VideoGenerationRequest) -> dict[str, ) if not _video_variant_available_locally(variant): - validation_error = _video_download_validation_error(variant["repo"]) + validation_error = _video_variant_validation_error(variant) detail = validation_error or f"{variant['name']} is not installed locally yet." raise HTTPException(status_code=409, detail=detail) + _unload_idle_image_runtime_for_video(request, "video generation") try: artifact, runtime = _generate_video_artifact(body, variant, state.video_runtime) except GenerationCancelled: @@ -305,12 +345,47 @@ def download_video_model(request: Request, body: DownloadModelRequest) -> dict[s at an arbitrary model via the API. 
""" state = request.app.state.chaosengine + variant = _find_video_variant(body.modelId) if body.modelId else None + if body.modelId and variant is None: + raise HTTPException(status_code=404, detail=f"Unknown video model '{body.modelId}'.") + if variant is not None and variant.get("ggufFile"): + base_error = _video_download_validation_error(str(variant["repo"])) + if base_error: + label = variant["name"] + state.add_log("video", "info", f"Video download requested: {label} base ({variant['repo']})") + return {"download": state.start_download(str(variant["repo"]))} + gguf_repo = str(variant.get("ggufRepo") or "") + gguf_file = str(variant.get("ggufFile") or "") + if not gguf_repo or not gguf_file: + raise HTTPException(status_code=400, detail=f"GGUF metadata is incomplete for {variant['name']}.") + state.add_log("video", "info", f"Video download requested: {variant['name']} GGUF ({gguf_repo}/{gguf_file})") + return { + "download": state.start_download( + gguf_repo, + allow_patterns=[gguf_file, "*.md", "LICENSE*"], + ) + } + + if variant is not None: + text_encoder_repo = _video_variant_missing_text_encoder_repo(variant) + if text_encoder_repo: + state.add_log( + "video", + "info", + f"Video download requested: {variant['name']} shared text encoder ({text_encoder_repo})", + ) + return { + "download": state.start_download( + text_encoder_repo, + allow_patterns=list(_VIDEO_MLX_TEXT_ENCODER_ALLOW_PATTERNS), + ) + } + if not _is_video_repo(body.repo): raise HTTPException( status_code=404, detail=f"Repo '{body.repo}' is not in the curated video model catalog.", ) - variant = _find_video_variant_by_repo(body.repo) label = variant["name"] if variant else body.repo state.add_log("video", "info", f"Video download requested: {label} ({body.repo})") return {"download": state.start_download(body.repo)} diff --git a/backend_service/state.py b/backend_service/state.py index 345eb86..ee874a6 100644 --- a/backend_service/state.py +++ b/backend_service/state.py @@ -18,7 +18,6 @@ from fastapi import HTTPException from starlette.responses import StreamingResponse -from cache_compression import registry as cache_registry from backend_service.catalog import CATALOG from backend_service.inference import RuntimeController @@ -170,6 +169,7 @@ def __init__( benchmarks_path: Path | None = None, chat_sessions_path: Path | None = None, library_cache_path: Path | None = None, + background_capability_probe: bool = False, ) -> None: # Defer imports of module-level constants to avoid circular imports from backend_service.app import ( @@ -210,7 +210,7 @@ def __init__( self._library_scan_done.set() else: self._library_scan_done.set() - self.runtime = RuntimeController() + self.runtime = RuntimeController(background_probe=background_capability_probe) self._image_runtime: "ImageRuntimeManager | None" = None self._video_runtime: "VideoRuntimeManager | None" = None self._chat_sessions_path = chat_sessions_path if chat_sessions_path is not None else CHAT_SESSIONS_PATH @@ -419,10 +419,16 @@ def _settings_payload(self, library: list[dict[str, Any]]) -> dict[str, Any]: "hfCachePath": str(self.settings.get("hfCachePath") or ""), } + def _system_snapshot(self) -> dict[str, Any]: + try: + return self._system_snapshot_provider(capabilities=self.runtime.capabilities) + except TypeError: + return self._system_snapshot_provider() + def _bootstrap(self) -> None: from backend_service.app import app_version - system = self._system_snapshot_provider() + system = self._system_snapshot() recommendation = _best_fit_recommendation(system) 
self.add_log("chaosengine", "info", f"Workspace booted in {system['backendLabel']} mode.") self.add_log("chaosengine", "info", f"ChaosEngine v{app_version} detected.") @@ -494,6 +500,8 @@ def _native_cache_label() -> str: return "Native f16 cache" def _cache_label(self, *, cache_strategy: str, bits: int, fp16_layers: int) -> str: + from cache_compression import registry as cache_registry + strategy = cache_registry.get(cache_strategy) if strategy is not None: return strategy.label(bits, fp16_layers) @@ -1233,7 +1241,7 @@ def _conversion_details( fp16_layers=launch_preferences["fp16Layers"], context_tokens=launch_preferences["contextTokens"], params_b=params_b, - system_stats=self._system_snapshot_provider(), + system_stats=self._system_snapshot(), ) if params_b is not None else None @@ -1341,7 +1349,7 @@ def run_benchmark(self, request: BenchmarkRunRequest) -> dict[str, Any]: fp16_layers=request.fp16Layers, context_tokens=request.contextTokens, params_b=params_b, - system_stats=self._system_snapshot_provider(), + system_stats=self._system_snapshot(), ) use_compressed = request.cacheBits > 0 cache_gb = preview["optimizedCacheGb"] if use_compressed else preview["baselineCacheGb"] @@ -2462,7 +2470,7 @@ def _sse_stream(): headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) - def start_download(self, repo: str) -> dict[str, Any]: + def start_download(self, repo: str, allow_patterns: list[str] | None = None) -> dict[str, Any]: from backend_service.helpers.huggingface import ( _friendly_hf_download_error, _hf_repo_downloaded_bytes, @@ -2569,7 +2577,7 @@ def _progress_worker() -> None: # allowlist so we skip legacy single-file checkpoints the # pipelines never load. Both helpers return None for repos # outside their catalog, so only one ever applies. - allow_patterns = ( + effective_allow_patterns = allow_patterns or ( _video_repo_allow_patterns(repo) or _image_repo_allow_patterns(repo) ) @@ -2577,7 +2585,7 @@ def _progress_worker() -> None: repo, env, process_log, - allow_patterns=allow_patterns, + allow_patterns=effective_allow_patterns, ) with self._lock: if self._download_tokens.get(repo) == download_token: @@ -2903,7 +2911,7 @@ def server_status(self) -> dict[str, Any]: def workspace(self) -> dict[str, Any]: from backend_service.app import compute_cache_preview - system_stats = self._system_snapshot_provider() + system_stats = self._system_snapshot() try: loaded_name = self.runtime.loaded_model.name if self.runtime.loaded_model else None loaded_engine = self.runtime.engine.engine_name if self.runtime.engine else None diff --git a/backend_service/video_runtime.py b/backend_service/video_runtime.py index e7426ce..f931a25 100644 --- a/backend_service/video_runtime.py +++ b/backend_service/video_runtime.py @@ -32,7 +32,6 @@ from backend_service.helpers.gpu import nvidia_gpu_present from backend_service.image_runtime import validate_local_diffusers_snapshot -from cache_compression import apply_diffusion_cache_strategy from backend_service.progress import ( GenerationCancelled, PHASE_DECODING, @@ -985,6 +984,8 @@ def generate(self, config: VideoGenerationConfig) -> GeneratedVideo: # ~1.3–2× on Wan). NotImplementedError is swallowed by the # helper when the pipeline class has no vendored patch yet; # see FU-007 in CLAUDE.md. 
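[Editor's sketch] The hunk below moves the cache_compression import to the point of use, which pairs with the registry change later in this patch (module import no longer calls discover()). A minimal sketch of the resulting cost profile, using only names that appear in this patch:

    # Importing the package is now cheap: discover() no longer runs at
    # import time, so a process that never touches cache strategies
    # pays no adapter-import cost.
    from cache_compression import registry

    native = registry.get("native")   # first lookup discovers once, under the RLock
    ids = [s.strategy_id for s in registry.strategies()]  # later calls hit the fast path

The trade is a one-time latency hit on the first strategy lookup instead of at sidecar startup, which is what the StartupProgressPanel copy change further down reflects.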
+ from cache_compression import apply_diffusion_cache_strategy + apply_diffusion_cache_strategy( pipeline, strategy_id=config.cacheStrategy, @@ -1527,6 +1528,11 @@ def _ensure_pipeline( pipeline_kwargs["transformer"] = quantized_transformer if gguf_note: VIDEO_PROGRESS.set_phase(PHASE_LOADING, message=gguf_note) + if quantized_transformer is None: + raise RuntimeError( + gguf_note + or f"Could not load requested GGUF transformer {gguf_file}." + ) elif use_nf4: VIDEO_PROGRESS.set_phase( PHASE_LOADING, @@ -1635,9 +1641,10 @@ def _try_load_gguf_transformer( Mirrors the image-side loader: GGUF weights cover the DiT only; VAE and text encoders are loaded from the base ``repo`` snapshot. - All failure modes are non-fatal — a missing ``gguf`` package, an - old diffusers without ``GGUFQuantizationConfig``, or an HF cache - miss falls back to the standard fp16 / bf16 transformer path. + The helper itself only reports ``(None, note)`` on failure so tests + can exercise each missing-dependency path. ``_ensure_pipeline`` treats + a requested GGUF variant as strict and raises with that note rather + than silently loading the full fp16 / bf16 transformer. """ if importlib.util.find_spec("gguf") is None: return None, ( diff --git a/cache_compression/__init__.py b/cache_compression/__init__.py index ad3cb81..1bcfa2c 100644 --- a/cache_compression/__init__.py +++ b/cache_compression/__init__.py @@ -10,6 +10,7 @@ from abc import ABC, abstractmethod import importlib import platform +from threading import RLock from typing import Any @@ -147,18 +148,27 @@ class CacheStrategyRegistry: def __init__(self) -> None: self._strategies: dict[str, CacheStrategy] = {} + self._discovered = False + self._lock = RLock() def register(self, strategy: CacheStrategy) -> None: self._strategies[strategy.strategy_id] = strategy def get(self, strategy_id: str) -> CacheStrategy | None: + self._ensure_discovered() return self._strategies.get(strategy_id) def default(self) -> CacheStrategy: + self._ensure_discovered() return self._strategies["native"] + def strategies(self) -> list[CacheStrategy]: + self._ensure_discovered() + return list(self._strategies.values()) + def available(self) -> list[dict[str, Any]]: """Return a JSON-friendly list for the frontend.""" + self._ensure_discovered() out: list[dict[str, Any]] = [] for s in self._strategies.values(): out.append({ @@ -176,11 +186,19 @@ def available(self) -> list[dict[str, Any]]: }) return out + def _ensure_discovered(self) -> None: + if self._discovered: + return + with self._lock: + if not self._discovered: + self.discover() + def discover(self) -> list[CacheStrategy]: """Import all known adapter modules and return available strategies.""" - self._strategies = {} + with self._lock: + self._strategies = {} - strategy_specs = [ + strategy_specs = [ { "id": "native", "name": "Native f16", @@ -248,31 +266,32 @@ def discover(self) -> list[CacheStrategy]: "supports_fp16_layers": False, "required_llama_binary": "standard", }, - ] - - for spec in strategy_specs: - try: - module = importlib.import_module(spec["module"]) - cls = getattr(module, spec["class_name"]) - instance = cls() - except Exception as exc: - if spec["id"] == "native": - raise - instance = _BrokenStrategy( - strategy_id=str(spec["id"]), - name=str(spec["name"]), - bit_range=spec["bit_range"], - default_bits=spec["default_bits"], - supports_fp16_layers=bool(spec["supports_fp16_layers"]), - required_llama_binary=str(spec.get("required_llama_binary", "standard")), - reason=( - f"{spec['name']} could not be loaded in this 
runtime. " - f"ChaosEngineAI kept the card visible so the UI does not silently collapse to Native f16 only. " - f"Import error: {exc}" - ), - ) - self.register(instance) - return list(self._strategies.values()) + ] + + for spec in strategy_specs: + try: + module = importlib.import_module(spec["module"]) + cls = getattr(module, spec["class_name"]) + instance = cls() + except Exception as exc: + if spec["id"] == "native": + raise + instance = _BrokenStrategy( + strategy_id=str(spec["id"]), + name=str(spec["name"]), + bit_range=spec["bit_range"], + default_bits=spec["default_bits"], + supports_fp16_layers=bool(spec["supports_fp16_layers"]), + required_llama_binary=str(spec.get("required_llama_binary", "standard")), + reason=( + f"{spec['name']} could not be loaded in this runtime. " + f"ChaosEngineAI kept the card visible so the UI does not silently collapse to Native f16 only. " + f"Import error: {exc}" + ), + ) + self.register(instance) + self._discovered = True + return list(self._strategies.values()) class _BrokenStrategy(CacheStrategy): @@ -330,7 +349,6 @@ def required_llama_binary(self) -> str: # Module-level singleton — import and use ``registry`` directly. registry = CacheStrategyRegistry() -registry.discover() def apply_diffusion_cache_strategy( diff --git a/cache_compression/chaosengine.py b/cache_compression/chaosengine.py index e088dfe..c8d9000 100644 --- a/cache_compression/chaosengine.py +++ b/cache_compression/chaosengine.py @@ -16,28 +16,16 @@ from __future__ import annotations -import importlib +import importlib.util from typing import Any from cache_compression import CacheStrategy -def _load_chaosengine() -> Any | None: - try: - return importlib.import_module("chaos_engine") - except ImportError: - return None - - def _chaosengine_available() -> bool: - mod = _load_chaosengine() - if mod is None: - return False - # Check for the core cache module try: - cache_mod = importlib.import_module("chaos_engine.cache") - return hasattr(cache_mod, "config") or True - except ImportError: + return importlib.util.find_spec("chaos_engine") is not None + except (ImportError, AttributeError, ValueError): return False diff --git a/cache_compression/triattention.py b/cache_compression/triattention.py index fe2d339..ef524ec 100644 --- a/cache_compression/triattention.py +++ b/cache_compression/triattention.py @@ -17,36 +17,23 @@ from __future__ import annotations +import importlib.util from typing import Any from cache_compression import CacheStrategy -_triattention = None -_vllm = None -_mlx_lm = None -try: - import triattention as _triattention # type: ignore[import-untyped] -except ImportError: - pass -try: - import vllm as _vllm # type: ignore[import-untyped] -except ImportError: - pass -try: - import mlx_lm as _mlx_lm # type: ignore[import-untyped] -except ImportError: - pass +def _module_available(module_name: str) -> bool: + try: + return importlib.util.find_spec(module_name) is not None + except (ImportError, AttributeError, ValueError): + return False def _has_mlx_entrypoint() -> bool: - if _triattention is None or _mlx_lm is None: - return False - try: - from triattention.mlx import apply_triattention_mlx # noqa: F401 - return True - except ImportError: - return False + # Keep availability checks side-effect free. Importing mlx_lm can touch + # MLX/Metal at module load and can abort in headless or sandboxed contexts. 
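[Editor's note] find_spec resolves loader metadata without executing the module, which is the property the comment above depends on. A small stdlib-only demonstration of that property (illustrative, not part of the patch):

    import importlib.util
    import sys

    name = "sqlite3"                       # a stdlib module, usually not yet loaded
    spec = importlib.util.find_spec(name)  # consults the finders; runs no module code
    print(spec is not None)                # True on a standard CPython build
    print(name in sys.modules)             # still False unless something imported it earlier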
+ return _module_available("triattention") and _module_available("mlx_lm") class TriAttentionStrategy(CacheStrategy): @@ -63,7 +50,7 @@ def has_mlx_backend(self) -> bool: return _has_mlx_entrypoint() def has_vllm_backend(self) -> bool: - return _triattention is not None and _vllm is not None + return _module_available("triattention") and _module_available("vllm") def is_available(self) -> bool: return self.has_mlx_backend() or self.has_vllm_backend() @@ -119,7 +106,7 @@ def apply_vllm_patches(self) -> None: Must be called BEFORE creating a ``vllm.LLM`` instance. """ - if _triattention is None: + if not _module_available("triattention"): raise RuntimeError("triattention is not installed.") try: from triattention.vllm.runtime.integration_monkeypatch import ( diff --git a/src/App.tsx b/src/App.tsx index 9aa3645..62636b3 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -1474,7 +1474,7 @@ export default function App() { longLiveJob={videoState.longLiveJob} onActiveTabChange={setActiveTab} onOpenVideoStudio={videoState.openVideoStudio} - onVideoDownload={(repo) => void videoState.handleVideoDownload(repo)} + onVideoDownload={(repo, modelId) => void videoState.handleVideoDownload(repo, modelId)} onCancelVideoDownload={(repo) => void videoState.handleCancelVideoDownload(repo)} onDeleteVideoDownload={(repo) => void videoState.handleDeleteVideoDownload(repo)} onOpenExternalUrl={(url) => void handleOpenExternalUrl(url)} @@ -1496,7 +1496,7 @@ export default function App() { fileRevealLabel={fileRevealLabel} onActiveTabChange={setActiveTab} onOpenVideoStudio={videoState.openVideoStudio} - onVideoDownload={(repo) => void videoState.handleVideoDownload(repo)} + onVideoDownload={(repo, modelId) => void videoState.handleVideoDownload(repo, modelId)} onCancelVideoDownload={(repo) => void videoState.handleCancelVideoDownload(repo)} onDeleteVideoDownload={(repo) => void videoState.handleDeleteVideoDownload(repo)} onPreloadVideoModel={(variant) => void videoState.handlePreloadVideoModel(variant)} @@ -1556,7 +1556,7 @@ export default function App() { onActiveTabChange={setActiveTab} onPreloadVideoModel={(variant) => void videoState.handlePreloadVideoModel(variant)} onUnloadVideoModel={(variant) => void videoState.handleUnloadVideoModel(variant)} - onVideoDownload={(repo) => void videoState.handleVideoDownload(repo)} + onVideoDownload={(repo, modelId) => void videoState.handleVideoDownload(repo, modelId)} onGenerateVideo={() => void videoState.handleVideoGenerate()} onOpenExternalUrl={(url) => void handleOpenExternalUrl(url)} onRestartServer={() => void handleRestartServer()} diff --git a/src/api.ts b/src/api.ts index 1741f99..1881b06 100644 --- a/src/api.ts +++ b/src/api.ts @@ -678,8 +678,8 @@ export async function unloadImageModel(modelId?: string): Promise { - const result = await postJson<{ download: DownloadStatus }>("/api/video/download", { repo }); +export async function downloadVideoModel(repo: string, modelId?: string): Promise { + const result = await postJson<{ download: DownloadStatus }>("/api/video/download", { repo, modelId }); return result.download; } diff --git a/src/components/LatestImageDiscoverCard.tsx b/src/components/LatestImageDiscoverCard.tsx deleted file mode 100644 index 982e2c3..0000000 --- a/src/components/LatestImageDiscoverCard.tsx +++ /dev/null @@ -1,172 +0,0 @@ -import type { ImageModelVariant } from "../types"; -import type { DownloadStatus } from "../api"; -import { - imagePrimarySizeLabel, - imageSecondarySizeLabel, - formatImageLicenseLabel, - formatImageAccessError, - 
formatReleaseLabel, - isGatedImageAccessError, -} from "../utils/format"; -import { downloadProgressLabel, downloadSizeTooltip } from "../utils/downloads"; - -export interface LatestImageDiscoverCardProps { - variant: ImageModelVariant; - downloadState?: DownloadStatus; - fileRevealLabel?: string; - onDownload: (repo: string) => void; - onCancelDownload: (repo: string) => void; - onDeleteDownload: (repo: string) => void; - onOpenExternalUrl: (url: string) => void; - onNavigateSettings: () => void; - onRevealPath?: (path: string) => void; -} - -export function LatestImageDiscoverCard({ - variant, - downloadState, - fileRevealLabel, - onDownload, - onCancelDownload, - onDeleteDownload, - onOpenExternalUrl, - onNavigateSettings, - onRevealPath, -}: LatestImageDiscoverCardProps) { - const isDownloadPaused = downloadState?.state === "cancelled"; - const isDownloadComplete = downloadState?.state === "completed"; - const isDownloadFailed = downloadState?.state === "failed"; - const hasLocalData = Boolean(variant.hasLocalData || isDownloadComplete || isDownloadPaused || isDownloadFailed); - const friendlyDownloadError = formatImageAccessError(downloadState?.error, variant); - const needsGatedAccess = isGatedImageAccessError(downloadState?.error); - return ( -
-
-
-
-

{variant.name}

- {variant.provider} - {!variant.availableLocally && isDownloadComplete ? Downloaded : null} - {isDownloadPaused ? Paused : null} - {isDownloadFailed ? Download Failed : null} -
-

{variant.note}

-
- {variant.updatedLabel ?? "Recently updated"} -
- -
- {imagePrimarySizeLabel(variant)} - {imageSecondarySizeLabel(variant) ? {imageSecondarySizeLabel(variant)} : null} - {variant.recommendedResolution} - {variant.pipelineTag ? {variant.pipelineTag} : null} -
- -
- {formatReleaseLabel(variant.releaseLabel, variant.releaseDate ?? variant.createdAt) ? ( - {formatReleaseLabel(variant.releaseLabel, variant.releaseDate ?? variant.createdAt)} - ) : null} - {variant.downloadsLabel ? {variant.downloadsLabel} : null} - {variant.likesLabel ? {variant.likesLabel} : null} - {variant.license ? {formatImageLicenseLabel(variant.license)} : null} - {typeof variant.gated === "boolean" ? {variant.gated ? "Gated access" : "Open access"} : null} -
- -
- {variant.taskSupport.map((task) => ( - {task} - ))} - {variant.styleTags.map((tag) => ( - {tag} - ))} -
- - {isDownloadFailed && downloadState?.error ? ( -
-

{friendlyDownloadError}

- {needsGatedAccess ? ( -
- - -
- ) : null} - {friendlyDownloadError !== downloadState.error ? ( -
- Technical details -

{downloadState.error}

-
- ) : null} -
- ) : null} - -
- {variant.availableLocally ? ( - Installed - ) : downloadState?.state === "downloading" ? ( - <> - {downloadProgressLabel(downloadState)} - - - - ) : isDownloadPaused ? ( - <> - {downloadProgressLabel(downloadState)} - - - - ) : isDownloadFailed ? ( - <> - - - - ) : isDownloadComplete ? ( - Download complete - ) : ( - <> - - {hasLocalData ? ( - - ) : null} - - )} - {variant.localPath && onRevealPath ? ( - - ) : null} - -
-
- ); -} diff --git a/src/components/StartupProgressPanel.tsx b/src/components/StartupProgressPanel.tsx index f3619a8..141c164 100644 --- a/src/components/StartupProgressPanel.tsx +++ b/src/components/StartupProgressPanel.tsx @@ -12,8 +12,8 @@ interface Props { // 1. Tauri extracts the bundled ~280 MB runtime tarball into a // manifest-hash-suffixed cache dir. Cold SSD + gunzip = 5-15 s. // 2. The Rust shell spawns the Python sidecar. Python 3.11 imports -// FastAPI + uvicorn + huggingface_hub + dflash registry + the -// image/video catalogs. First-time page cache warmup = 10-25 s. +// the core FastAPI app. Heavier image/video/cache runtimes stay lazy +// until their routes are used. // 3. The FastAPI server finishes binding its port and answers // /api/workspace, which releases the splash. // @@ -85,14 +85,14 @@ function pickPhase( if (elapsedSeconds < 25) { return { title: "Starting Python runtime", - detail: "Loading FastAPI, HuggingFace hub, and the cache strategies.", + detail: "Loading the core API and restoring workspace state.", }; } if (elapsedSeconds < 45) { return { - title: "Importing modules", + title: "Waiting for backend", detail: - "Warming up diffusers / MLX / dflash — most of the wait on a cold start.", + "The sidecar is still binding its API port and checking local runtime state.", }; } return { diff --git a/src/features/images/ImageDiscoverTab.tsx b/src/features/images/ImageDiscoverTab.tsx index f7f60f9..fd5a576 100644 --- a/src/features/images/ImageDiscoverTab.tsx +++ b/src/features/images/ImageDiscoverTab.tsx @@ -1,5 +1,5 @@ +import { useMemo, useState } from "react"; import { Panel } from "../../components/Panel"; -import { LatestImageDiscoverCard } from "../../components/LatestImageDiscoverCard"; import type { DownloadStatus } from "../../api"; import type { ImageModelVariant, @@ -10,6 +10,19 @@ import type { ImageDiscoverTaskFilter, ImageDiscoverAccessFilter, } from "../../types/image"; +import { + downloadProgressLabel, + downloadSizeTooltip, + formatImageAccessError, + formatImageLicenseLabel, + formatReleaseLabel, + imageDiscoverMemoryEstimate, + imagePrimarySizeLabel, + imageSecondarySizeLabel, + isGatedImageAccessError, +} from "../../utils"; + +type MediaStatusFilter = "all" | "installed" | "not-installed" | "downloading" | "paused" | "failed" | "incomplete"; export interface ImageDiscoverTabProps { combinedImageDiscoverResults: ImageModelVariant[]; @@ -35,6 +48,43 @@ export interface ImageDiscoverTabProps { onRevealPath: (path: string) => void; } +function imageDiscoverSortLabel(sort: DiscoverSort): string { + if (sort === "size") return "largest size first"; + if (sort === "ram") return "highest RAM/VRAM first"; + if (sort === "likes") return "most liked first"; + if (sort === "downloads") return "most downloads first"; + return "newest released first"; +} + +function sortIndicator(activeSort: DiscoverSort, key: DiscoverSort): string { + return activeSort === key ? 
" \u25BC" : ""; +} + +function imageVariantStatus( + variant: ImageModelVariant, + downloadState?: DownloadStatus, +): MediaStatusFilter { + if (variant.availableLocally || downloadState?.state === "completed") return "installed"; + if (downloadState?.state === "downloading") return "downloading"; + if (downloadState?.state === "cancelled") return "paused"; + if (downloadState?.state === "failed") return "failed"; + if (variant.hasLocalData) return "incomplete"; + return "not-installed"; +} + +function statusBadge(status: MediaStatusFilter, downloadState?: DownloadStatus) { + if (status === "installed") return Installed; + if (status === "downloading" && downloadState) { + return {downloadProgressLabel(downloadState)}; + } + if (status === "paused" && downloadState) { + return {downloadProgressLabel(downloadState)}; + } + if (status === "failed") return Download Failed; + if (status === "incomplete") return Incomplete; + return Not installed; +} + export function ImageDiscoverTab({ combinedImageDiscoverResults, imageDiscoverSearchInput, @@ -58,11 +108,22 @@ export function ImageDiscoverTab({ onOpenExternalUrl, onRevealPath, }: ImageDiscoverTabProps) { + const [statusFilter, setStatusFilter] = useState("all"); + const filteredResults = useMemo( + () => + combinedImageDiscoverResults.filter((variant) => { + if (statusFilter === "all") return true; + return imageVariantStatus(variant, activeImageDownloads[variant.repo]) === statusFilter; + }), + [activeImageDownloads, combinedImageDiscoverResults, statusFilter], + ); + const hasActiveFilters = imageDiscoverHasActiveFilters || statusFilter !== "all"; + return (
@@ -81,7 +142,7 @@ export function ImageDiscoverTab({
-
+
+
diff --git a/src/features/images/ImageStudioTab.tsx b/src/features/images/ImageStudioTab.tsx index ab7573e..70f64c9 100644 --- a/src/features/images/ImageStudioTab.tsx +++ b/src/features/images/ImageStudioTab.tsx @@ -231,15 +231,20 @@ export function ImageStudioTab({ assessImageGenerationSafety({ width: imageWidth, height: imageHeight, - device: imageRuntimeStatus.device, + device: imageRuntimeStatus.device ?? imageRuntimeStatus.expectedDevice, deviceMemoryGb: imageRuntimeStatus.deviceMemoryGb, baseModelFootprintGb: selectedImageVariant?.sizeGb, + repo: selectedImageVariant?.repo, + ggufFile: selectedImageVariant?.ggufFile, }), [ imageWidth, imageHeight, imageRuntimeStatus.device, + imageRuntimeStatus.expectedDevice, imageRuntimeStatus.deviceMemoryGb, + selectedImageVariant?.repo, + selectedImageVariant?.ggufFile, selectedImageVariant?.sizeGb, ], ); diff --git a/src/features/video/VideoDiscoverTab.tsx b/src/features/video/VideoDiscoverTab.tsx index 475bdb3..e9f2ade 100644 --- a/src/features/video/VideoDiscoverTab.tsx +++ b/src/features/video/VideoDiscoverTab.tsx @@ -1,4 +1,4 @@ -import { useEffect } from "react"; +import { useEffect, useMemo, useState } from "react"; import { InstallLogPanel } from "../../components/InstallLogPanel"; import { Panel } from "../../components/Panel"; import type { DownloadStatus, InstallResult, LongLiveJobState } from "../../api"; @@ -14,17 +14,19 @@ import { downloadSizeTooltip, formatReleaseLabel, number, + videoDiscoverMemoryEstimate, + videoDownloadStatusForVariant, videoPrimarySizeLabel, videoSecondarySizeLabel, } from "../../utils"; +type MediaStatusFilter = "all" | "installed" | "not-installed" | "downloading" | "paused" | "failed" | "incomplete"; + // LongLive ships via a dedicated Python installer (isolated venv + GitHub // clone + HF weights at Efficient-Large-Model/LongLive-1.3B), not via // snapshot_download. The catalog repo id ``NVlabs/LongLive-1.3B`` is the // GitHub org and intentionally does not resolve on Hugging Face — we use -// it purely as a routing key. Detect LongLive here so the Discover card -// can swap the Download button for an Install LongLive CTA that matches -// the Studio tab's existing install affordance. +// it purely as a routing key. function isLongLiveRepo(repo: string | undefined): boolean { return repo?.startsWith("NVlabs/LongLive") ?? false; } @@ -44,13 +46,10 @@ export interface VideoDiscoverTabProps { fileRevealLabel: string; longLiveStatus: VideoRuntimeStatus | null; installingLongLive: boolean; - // Live LongLive install job — same async-poll job as VideoStudioTab so - // either tab's "Install LongLive" button drives the same backend - // worker and renders the same per-phase terminal output. 
longLiveJob: LongLiveJobState | null; onActiveTabChange: (tab: TabId) => void; onOpenVideoStudio: (modelId?: string) => void; - onVideoDownload: (repo: string) => void; + onVideoDownload: (repo: string, modelId?: string) => void; onCancelVideoDownload: (repo: string) => void; onDeleteVideoDownload: (repo: string) => void; onOpenExternalUrl: (url: string) => void; @@ -59,6 +58,51 @@ export interface VideoDiscoverTabProps { onInstallLongLive: () => Promise; } +function videoDiscoverSortLabel(sort: DiscoverSort): string { + if (sort === "size") return "largest size first"; + if (sort === "ram") return "highest RAM/VRAM first"; + if (sort === "likes") return "most liked first"; + if (sort === "downloads") return "most downloads first"; + return "newest released first"; +} + +function sortIndicator(activeSort: DiscoverSort, key: DiscoverSort): string { + return activeSort === key ? " \u25BC" : ""; +} + +function videoVariantStatus( + variant: VideoModelVariant, + downloadState: DownloadStatus | undefined, + longLiveReady: boolean, + installingLongLive: boolean, +): MediaStatusFilter { + if (isLongLiveRepo(variant.repo)) { + if (longLiveReady) return "installed"; + if (installingLongLive) return "downloading"; + return "not-installed"; + } + if (variant.availableLocally || downloadState?.state === "completed") return "installed"; + if (downloadState?.state === "downloading") return "downloading"; + if (downloadState?.state === "cancelled") return "paused"; + if (downloadState?.state === "failed") return "failed"; + if (variant.hasLocalData) return "incomplete"; + return "not-installed"; +} + +function statusBadge(status: MediaStatusFilter, downloadState?: DownloadStatus, longLiveInstalling = false) { + if (status === "installed") return Installed; + if (longLiveInstalling) return Installing…; + if (status === "downloading" && downloadState) { + return {downloadProgressLabel(downloadState)}; + } + if (status === "paused" && downloadState) { + return {downloadProgressLabel(downloadState)}; + } + if (status === "failed") return Download Failed; + if (status === "incomplete") return Incomplete; + return Not installed; +} + export function VideoDiscoverTab({ combinedVideoDiscoverResults, videoDiscoverSearchInput, @@ -85,22 +129,31 @@ export function VideoDiscoverTab({ onRefreshLongLiveStatus, onInstallLongLive, }: VideoDiscoverTabProps) { - // Probe LongLive install state whenever the results include a LongLive - // variant so the card can render "Installed" vs "Install LongLive" - // without the user having to open the Studio tab first. Mirrors the - // same effect in VideoStudioTab. const hasLongLiveVariant = combinedVideoDiscoverResults.some((variant) => isLongLiveRepo(variant.repo), ); useEffect(() => { if (hasLongLiveVariant) onRefreshLongLiveStatus(); }, [hasLongLiveVariant, onRefreshLongLiveStatus]); + + const [statusFilter, setStatusFilter] = useState("all"); const longLiveReady = longLiveStatus?.realGenerationAvailable ?? false; + const filteredResults = useMemo( + () => + combinedVideoDiscoverResults.filter((variant) => { + if (statusFilter === "all") return true; + const downloadState = videoDownloadStatusForVariant(activeVideoDownloads, variant); + return videoVariantStatus(variant, downloadState, longLiveReady, installingLongLive) === statusFilter; + }), + [activeVideoDownloads, combinedVideoDiscoverResults, installingLongLive, longLiveReady, statusFilter], + ); + const hasActiveFilters = videoDiscoverHasActiveFilters || statusFilter !== "all"; + return (
@@ -143,6 +196,22 @@ export function VideoDiscoverTab({ +
)}
diff --git a/src/features/video/VideoModelsTab.tsx b/src/features/video/VideoModelsTab.tsx index 83760b3..cb2e197 100644 --- a/src/features/video/VideoModelsTab.tsx +++ b/src/features/video/VideoModelsTab.tsx @@ -1,3 +1,4 @@ +import { useMemo, useState } from "react"; import { Panel } from "../../components/Panel"; import type { DownloadStatus } from "../../api"; import type { @@ -6,7 +7,18 @@ import type { VideoModelVariant, VideoRuntimeStatus, } from "../../types"; -import { downloadProgressLabel, formatReleaseLabel, number, videoPrimarySizeLabel } from "../../utils"; +import { + downloadProgressLabel, + formatReleaseLabel, + number, + videoDiscoverMemoryEstimate, + videoDownloadStatusForVariant, + videoPrimarySizeLabel, + videoSecondarySizeLabel, +} from "../../utils"; + +type InstalledVideoSort = "date" | "size" | "ram" | "name"; +type InstalledVideoStatusFilter = "all" | "loaded" | "installed" | "incomplete" | "downloading" | "paused" | "failed"; export interface VideoModelsTabProps { installedVideoVariants: VideoModelVariant[]; @@ -19,7 +31,7 @@ export interface VideoModelsTabProps { fileRevealLabel: string; onActiveTabChange: (tab: TabId) => void; onOpenVideoStudio: (modelId?: string) => void; - onVideoDownload: (repo: string) => void; + onVideoDownload: (repo: string, modelId?: string) => void; onCancelVideoDownload: (repo: string) => void; onDeleteVideoDownload: (repo: string) => void; onPreloadVideoModel: (variant: VideoModelVariant) => void; @@ -28,6 +40,60 @@ export interface VideoModelsTabProps { onRevealPath: (path: string) => void; } +function releaseSortKey(variant: VideoModelVariant): string { + return variant.releaseDate ?? variant.createdAt ?? variant.lastModified ?? ""; +} + +function sizeSortKey(variant: VideoModelVariant): number | null { + const candidates = [variant.onDiskGb, variant.coreWeightsGb, variant.repoSizeGb, variant.sizeGb]; + for (const value of candidates) { + if (typeof value === "number" && Number.isFinite(value) && value > 0) return value; + } + return null; +} + +function compareNullableNumberDesc(left: number | null, right: number | null): number { + const leftKnown = typeof left === "number" && Number.isFinite(left); + const rightKnown = typeof right === "number" && Number.isFinite(right); + if (leftKnown && rightKnown) return (right as number) - (left as number); + if (leftKnown) return -1; + if (rightKnown) return 1; + return 0; +} + +function videoStatus( + variant: VideoModelVariant, + downloadState: DownloadStatus | undefined, + loadedVideoVariant: VideoModelVariant | null, +): InstalledVideoStatusFilter { + if (loadedVideoVariant?.id === variant.id) return "loaded"; + if (downloadState?.state === "downloading") return "downloading"; + if (downloadState?.state === "cancelled") return "paused"; + if (downloadState?.state === "failed") return "failed"; + if (variant.availableLocally || downloadState?.state === "completed") return "installed"; + return "incomplete"; +} + +function statusBadge(status: InstalledVideoStatusFilter, downloadState?: DownloadStatus) { + if (status === "loaded") return In Memory; + if (status === "installed") return Installed; + if (status === "downloading" && downloadState) return {downloadProgressLabel(downloadState)}; + if (status === "paused" && downloadState) return {downloadProgressLabel(downloadState)}; + if (status === "failed") return Download Failed; + return Incomplete; +} + +function sortIndicator(activeSort: InstalledVideoSort, key: InstalledVideoSort): string { + return activeSort === key ? 
" \u25BC" : ""; +} + +function sortLabel(sort: InstalledVideoSort): string { + if (sort === "size") return "largest size first"; + if (sort === "ram") return "highest RAM/VRAM first"; + if (sort === "name") return "name A-Z"; + return "newest released first"; +} + export function VideoModelsTab({ installedVideoVariants, videoCatalog, @@ -47,12 +113,62 @@ export function VideoModelsTab({ onOpenExternalUrl, onRevealPath, }: VideoModelsTabProps) { + const [searchInput, setSearchInput] = useState(""); + const [taskFilter, setTaskFilter] = useState<"all" | VideoModelVariant["taskSupport"][number]>("all"); + const [statusFilter, setStatusFilter] = useState("all"); + const [sort, setSort] = useState("date"); + const normalizedSearch = searchInput.trim().toLowerCase(); + const hasActiveFilters = + normalizedSearch.length > 0 || taskFilter !== "all" || statusFilter !== "all" || sort !== "date"; + + const rows = useMemo(() => { + return installedVideoVariants + .map((variant) => { + const family = videoCatalog.find((item) => + item.variants.some((candidate) => candidate.id === variant.id), + ); + const downloadState = videoDownloadStatusForVariant(activeVideoDownloads, variant); + const status = videoStatus(variant, downloadState, loadedVideoVariant); + const memoryEstimate = videoDiscoverMemoryEstimate(variant); + return { variant, family, downloadState, status, memoryEstimate }; + }) + .filter(({ variant, family, status }) => { + if (taskFilter !== "all" && !variant.taskSupport.includes(taskFilter)) return false; + if (statusFilter !== "all" && status !== statusFilter) return false; + if (!normalizedSearch) return true; + const haystack = [ + variant.name, + variant.provider, + variant.repo, + variant.runtime, + family?.name ?? "", + variant.recommendedResolution, + variant.styleTags.join(" "), + variant.taskSupport.join(" "), + ].join(" ").toLowerCase(); + return haystack.includes(normalizedSearch); + }) + .sort((left, right) => { + if (sort === "name") return left.variant.name.localeCompare(right.variant.name); + if (sort === "size") { + const diff = compareNullableNumberDesc(sizeSortKey(left.variant), sizeSortKey(right.variant)); + if (diff !== 0) return diff; + } else if (sort === "ram") { + const diff = compareNullableNumberDesc(left.memoryEstimate?.estimatedPeakGb ?? null, right.memoryEstimate?.estimatedPeakGb ?? null); + if (diff !== 0) return diff; + } + const dateDiff = releaseSortKey(right.variant).localeCompare(releaseSortKey(left.variant)); + if (dateDiff !== 0) return dateDiff; + return left.variant.name.localeCompare(right.variant.name); + }); + }, [activeVideoDownloads, installedVideoVariants, loadedVideoVariant, normalizedSearch, sort, statusFilter, taskFilter, videoCatalog]); + return (
0 - ? `${installedVideoVariants.length} model${installedVideoVariants.length !== 1 ? "s" : ""} with local data` + ? `${rows.length} of ${installedVideoVariants.length} model${installedVideoVariants.length !== 1 ? "s" : ""} with local data` : "No video models detected locally yet"} className="span-2" actions={ @@ -66,123 +182,217 @@ export function VideoModelsTab({

Download a video model from Video Discover to get started.

) : ( -
- {installedVideoVariants.map((variant) => { - const family = videoCatalog.find((item) => - item.variants.some((candidate) => candidate.id === variant.id), - ); - const isComplete = variant.availableLocally; - const isPartial = !isComplete && variant.hasLocalData; - const downloadState = activeVideoDownloads[variant.repo]; - const isDownloading = downloadState?.state === "downloading"; - const isPaused = downloadState?.state === "cancelled"; - const isDownloadComplete = downloadState?.state === "completed"; - const isDownloadFailed = downloadState?.state === "failed"; - const canDeleteLocalData = Boolean( - isComplete || isDownloadComplete || isPaused || isDownloadFailed || isPartial, - ); - const isLoadedInMemory = loadedVideoVariant?.id === variant.id; - const canPreload = isComplete && videoRuntimeStatus.realGenerationAvailable && !isLoadedInMemory; - return ( -
-
-
-

{variant.name}

-

{family?.name ?? variant.provider}

-
- {isLoadedInMemory ? ( - In Memory - ) : isComplete || isDownloadComplete ? ( - Installed - ) : isDownloading ? ( - {downloadProgressLabel(downloadState)} - ) : isPaused ? ( - {downloadProgressLabel(downloadState)} - ) : isDownloadFailed ? ( - Download Failed - ) : isPartial ? ( - Incomplete - ) : null} -
-
- {videoPrimarySizeLabel(variant)} - {variant.recommendedResolution} - {number(variant.defaultDurationSeconds)}s clip - {formatReleaseLabel(variant.releaseLabel, variant.releaseDate) ? ( - {formatReleaseLabel(variant.releaseLabel, variant.releaseDate)} - ) : null} - {variant.styleTags.slice(0, 3).map((tag) => ( - {tag} - ))} -
- {isDownloadFailed && downloadState?.error ? ( -

{downloadState.error}

- ) : null} -
- {isComplete || isDownloadComplete ? ( - - ) : isDownloading ? ( - - ) : isPaused ? ( - - ) : ( - - )} - {canPreload ? ( - - ) : null} - {isLoadedInMemory ? ( - - ) : null} - {isDownloading || canDeleteLocalData ? ( - - ) : null} - {variant.localPath ? ( - - ) : null} - -
-
- ); - })} -
+ <> +
+ + + + +
+ +
+
+
+ {rows.length} model{rows.length !== 1 ? "s" : ""} · {sortLabel(sort)} + {normalizedSearch ? Search: {searchInput.trim()} : null} + {taskFilter !== "all" ? Task: {taskFilter} : null} + {statusFilter !== "all" ? Status: {statusFilter} : null} +
+ {rows.length === 0 ? ( +
+

No installed video models match the current filters.

+
+ ) : ( +
+
+ + Provider + Tasks + + + Spec + + Status + +
+
+ {rows.map(({ variant, family, downloadState, status, memoryEstimate }) => { + const isLoadedInMemory = status === "loaded"; + const isComplete = status === "loaded" || status === "installed"; + const isDownloading = status === "downloading"; + const isPaused = status === "paused"; + const isDownloadFailed = status === "failed"; + const isPartial = status === "incomplete"; + const canDeleteLocalData = Boolean(isComplete || isPaused || isDownloadFailed || isPartial); + const localStatusReason = !isComplete && !isDownloading ? variant.localStatusReason : null; + const canPreload = isComplete && videoRuntimeStatus.realGenerationAvailable && !isLoadedInMemory; + const secondarySize = videoSecondarySizeLabel(variant); + const releaseLabel = formatReleaseLabel(variant.releaseLabel, variant.releaseDate ?? variant.createdAt); + return ( +
+
+
+ {variant.name} + {family?.name ?? variant.provider} +
+ {variant.styleTags.slice(0, 4).map((tag) => ( + {tag} + ))} +
+
+ {variant.provider} +
+ {variant.taskSupport.map((task) => ( + {task} + ))} +
+ + {videoPrimarySizeLabel(variant)} + {secondarySize ? {secondarySize} : null} + + + {memoryEstimate?.label ?? "pending"} + + + {variant.recommendedResolution} + {number(variant.defaultDurationSeconds)}s clip + + {releaseLabel ?? "Unknown"} + {statusBadge(status, downloadState)} +
+ {isComplete ? ( + + ) : isDownloading ? ( + + ) : ( + + )} + {canPreload ? ( + + ) : null} + {isLoadedInMemory ? ( + + ) : null} + {isDownloading || canDeleteLocalData ? ( + + ) : null} + {variant.localPath ? ( + + ) : null} + +
+
+ {isDownloadFailed && downloadState?.error ? ( +
+

{downloadState.error}

+
+ ) : localStatusReason ? ( +
+

{localStatusReason}

+
+ ) : null} +
+ ); + })} +
+
+ )} + )}
diff --git a/src/features/video/VideoStudioTab.tsx b/src/features/video/VideoStudioTab.tsx index 26448ed..825bb82 100644 --- a/src/features/video/VideoStudioTab.tsx +++ b/src/features/video/VideoStudioTab.tsx @@ -15,6 +15,7 @@ import { defaultVideoVariantForFamily, downloadProgressLabel, number, + videoDownloadStatusForVariant, videoPrimarySizeLabel, videoSecondarySizeLabel, } from "../../utils"; @@ -68,7 +69,7 @@ export interface VideoStudioTabProps { onActiveTabChange: (tab: TabId) => void; onPreloadVideoModel: (variant: VideoModelVariant) => void; onUnloadVideoModel: (variant?: VideoModelVariant) => void; - onVideoDownload: (repo: string) => void; + onVideoDownload: (repo: string, modelId?: string) => void; onGenerateVideo: () => void; onOpenExternalUrl: (url: string) => void; onRestartServer: () => void; @@ -285,6 +286,7 @@ export function VideoStudioTab({ const mp4EncoderMissing = missingDependencies.some( (dep) => dep === "imageio" || dep === "imageio-ffmpeg", ); + const gpuBundleRestartRequired = gpuBundleJob?.phase === "done" && gpuBundleJob.requiresRestart; // Tokenizer / text-encoder packages individual pipelines need lazily — // tiktoken for LTX-Video, sentencepiece for Wan / HunyuanVideo / CogVideoX // / Mochi, plus the protobuf + ftfy support libs. We list them out as a @@ -343,7 +345,7 @@ export function VideoStudioTab({ variants: family.variants.filter((variant) => { if (variant.availableLocally) return true; if (variant.hasLocalData) return true; - const downloadState = activeVideoDownloads[variant.repo]; + const downloadState = videoDownloadStatusForVariant(activeVideoDownloads, variant); return downloadState?.state === "downloading" || downloadState?.state === "completed"; }), })) @@ -426,7 +428,7 @@ export function VideoStudioTab({ && !(mlxVideoStatus.missingDependencies ?? []).includes("mlx-video"); const downloadState = useMemo( - () => (selectedVideoVariant ? activeVideoDownloads[selectedVideoVariant.repo] : undefined), + () => (selectedVideoVariant ? videoDownloadStatusForVariant(activeVideoDownloads, selectedVideoVariant) : undefined), [activeVideoDownloads, selectedVideoVariant], ); const isDownloading = downloadState?.state === "downloading"; @@ -457,6 +459,8 @@ export function VideoStudioTab({ ? "Choose a video model first." : !isDownloaded ? `${selectedVideoVariant.name} is not installed locally yet.` + : gpuBundleRestartRequired + ? "Restart the backend to activate the newly installed GPU runtime before generating." : !selectedVideoRuntimeStatus.realGenerationAvailable ? (selectedVideoRuntimeStatus.message || "Video runtime is not ready.") : !hasPrompt @@ -599,6 +603,9 @@ export function VideoStudioTab({ {videoRuntimeStatus.realGenerationAvailable ? "Real engine ready" : "Fallback active"} + {gpuBundleRestartRequired ? ( + Restart required + ) : null} Engine: {videoRuntimeStatus.activeEngine} {/* Prefer the actual-loaded device; fall back to the predicted * expectedDevice computed via nvidia-smi + find_spec (no torch @@ -733,36 +740,31 @@ export function VideoStudioTab({
) : null} - {!videoRuntimeStatus.realGenerationAvailable ? ( + {gpuBundleRestartRequired ? ( + <> +
+

+ GPU runtime installed to{" "}
+ {gpuBundleJob.targetDir ?? "extras"}. The running backend
+ still has its old import cache — click Restart Backend to activate the
+ new runtime; video generation will then use it.
+

+
+ +
+
+ + + ) : !videoRuntimeStatus.realGenerationAvailable ? ( <>
- {/* Same post-install-awaiting-restart branch Image Studio - * uses. After a successful GPU bundle install, the - * running backend still can't see the new torch in - * extras (PYTHONPATH is snapshotted at spawn). Nudge - * the user toward Restart Backend instead of asking - * them to install again. */} - {gpuBundleJob?.phase === "done" && gpuBundleJob.requiresRestart ? ( - <> -

- GPU runtime installed to{" "} - {gpuBundleJob.targetDir ?? "extras"}. The running backend - still has its old import cache — click Restart Backend to activate the - new runtime, then video generation will use your GPU. -

-
- -
- - ) : ( - <>

Video generation needs the GPU runtime bundle (torch + diffusers + tokenizers, ~2.5 GB). Install it once — it writes to a persistent user-local directory so @@ -781,8 +783,6 @@ export function VideoStudioTab({ {busyAction === "Restarting server..." ? "Restarting..." : "Restart Backend"}

- - )}
@@ -800,7 +800,7 @@ export function VideoStudioTab({ > {studioFamilies.flatMap((family) => family.variants.map((variant) => { - const downloadState = activeVideoDownloads[variant.repo]; + const downloadState = videoDownloadStatusForVariant(activeVideoDownloads, variant); const isDownloadingVariant = downloadState?.state === "downloading"; const suffix = variant.availableLocally ? " (installed)" @@ -846,7 +846,9 @@ export function VideoStudioTab({ ) : isDownloading ? ( {downloadProgressLabel(downloadState)} ) : ( - Not downloaded + + {selectedVideoVariant.hasLocalData ? "Incomplete" : "Not downloaded"} + )} {selectedVideoLoaded ? In Memory : null} {videoRuntimeLoadedDifferentModel && loadedVideoVariant ? ( @@ -855,6 +857,12 @@ export function VideoStudioTab({ ) : null} + {selectedVideoVariant?.localStatusReason && !isDownloaded && !isDownloading ? ( +

+ {selectedVideoVariant.localStatusReason} +

+ ) : null} +