diff --git a/backend_service/app.py b/backend_service/app.py index e226b08..86977d7 100644 --- a/backend_service/app.py +++ b/backend_service/app.py @@ -8,25 +8,20 @@ import uuid from datetime import datetime from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from starlette.responses import JSONResponse -from backend_service.image_runtime import ( - ImageGenerationConfig, - ImageRuntimeManager, -) -from backend_service.video_runtime import ( - VideoGenerationConfig, - VideoRuntimeManager, - start_torch_warmup, -) from backend_service.models import ImageGenerationRequest, VideoGenerationRequest from backend_service.routes import register_routes from backend_service.state import ChaosEngineState +if TYPE_CHECKING: + from backend_service.image_runtime import ImageRuntimeManager + from backend_service.video_runtime import VideoRuntimeManager + # --------------------------------------------------------------------------- # Helper modules -- extracted from this file for maintainability. # --------------------------------------------------------------------------- @@ -121,8 +116,8 @@ # extracted signatures require them explicitly. # --------------------------------------------------------------------------- -def _build_system_snapshot() -> dict[str, Any]: - return _build_system_snapshot_impl(app_version, APP_STARTED_AT) +def _build_system_snapshot(*, capabilities: Any | None = None) -> dict[str, Any]: + return _build_system_snapshot_impl(app_version, APP_STARTED_AT, capabilities=capabilities) def _default_settings() -> dict[str, Any]: @@ -231,6 +226,7 @@ def compute_cache_preview( fp16_layers: int = 4, num_layers: int = 32, num_heads: int = 32, + num_kv_heads: int | None = None, hidden_size: int = 4096, context_tokens: int = 8192, params_b: float = 7.0, @@ -242,6 +238,7 @@ def compute_cache_preview( fp16_layers=fp16_layers, num_layers=num_layers, num_heads=num_heads, + num_kv_heads=num_kv_heads, hidden_size=hidden_size, context_tokens=context_tokens, params_b=params_b, @@ -343,6 +340,8 @@ def _generate_image_artifacts( runtime_manager: ImageRuntimeManager | None = None, ) -> tuple[list[dict[str, Any]], dict[str, Any]]: import logging + from backend_service.image_runtime import ImageGenerationConfig, ImageRuntimeManager + logger = logging.getLogger("chaosengine.images") effective_width, effective_height = ( _apply_draft_resolution(request.width, request.height) @@ -413,6 +412,8 @@ def _generate_video_artifact( HTTP error rather than a fake clip. """ import logging + from backend_service.video_runtime import VideoGenerationConfig + logger = logging.getLogger("chaosengine.video") logger.info( "Generating video: model=%s repo=%s size=%dx%d frames=%d steps=%d", @@ -489,7 +490,10 @@ def create_app( allow_methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"], allow_headers=["Accept", "Authorization", "Content-Type", "X-ChaosEngine-Token"], ) - app.state.chaosengine = state or ChaosEngineState(server_port=DEFAULT_PORT) + app.state.chaosengine = state or ChaosEngineState( + server_port=DEFAULT_PORT, + background_capability_probe=True, + ) app.state.chaosengine_api_token = _resolve_api_token(api_token) app.state.chaosengine_allowed_origins = frozenset(allowed_origins) # Bearer-token enforcement toggle. 
Reads from (in order) env override, diff --git a/backend_service/catalog/image_models.py b/backend_service/catalog/image_models.py index d890a46..fce458b 100644 --- a/backend_service/catalog/image_models.py +++ b/backend_service/catalog/image_models.py @@ -44,7 +44,7 @@ "taskSupport": ["txt2img"], "sizeGb": 6.8, "recommendedResolution": "1024x1024", - "note": "GGUF Q4_K_M — runs on ~8 GB VRAM / Apple Silicon with near-unchanged quality.", + "note": "GGUF Q4_K_M — quantizes the FLUX transformer; the full diffusers pipeline still carries the base text encoders/VAE in memory.", "estimatedGenerationSeconds": 5.2, "releaseDate": "2024-09", }, @@ -62,7 +62,7 @@ "taskSupport": ["txt2img"], "sizeGb": 12.7, "recommendedResolution": "1024x1024", - "note": "GGUF Q8_0 — near-bf16 quality at ~half the memory footprint.", + "note": "GGUF Q8_0 — near-bf16 transformer quality; text encoders/VAE still make the full FLUX runtime memory-heavy.", "estimatedGenerationSeconds": 4.8, "releaseDate": "2024-09", }, @@ -125,7 +125,7 @@ "taskSupport": ["txt2img"], "sizeGb": 6.8, "recommendedResolution": "1024x1024", - "note": "GGUF Q4_K_M — fits FLUX Dev on 8 GB VRAM / Apple Silicon with minimal quality loss.", + "note": "GGUF Q4_K_M — quantizes the FLUX Dev transformer; expect the full diffusers pipeline to remain memory-heavy from text encoders/VAE.", "estimatedGenerationSeconds": 9.0, "releaseDate": "2024-09", }, @@ -143,7 +143,7 @@ "taskSupport": ["txt2img"], "sizeGb": 9.9, "recommendedResolution": "1024x1024", - "note": "GGUF Q6_K — mid-point between Q4 size and Q8 quality.", + "note": "GGUF Q6_K — mid-point between Q4 size and Q8 transformer quality; FLUX text encoders/VAE still dominate runtime memory.", "estimatedGenerationSeconds": 8.4, "releaseDate": "2024-09", }, @@ -161,7 +161,7 @@ "taskSupport": ["txt2img"], "sizeGb": 12.7, "recommendedResolution": "1024x1024", - "note": "GGUF Q8_0 — near-bf16 quality at roughly half the memory.", + "note": "GGUF Q8_0 — near-bf16 transformer quality; text encoders/VAE still make the full FLUX runtime memory-heavy.", "estimatedGenerationSeconds": 7.8, "releaseDate": "2024-09", }, @@ -271,6 +271,154 @@ ] LATEST_IMAGE_TRACKED_SEEDS: list[dict[str, Any]] = [ + { + "repo": "baidu/ERNIE-Image", + "name": "ERNIE-Image", + "provider": "Baidu", + "styleTags": ["general", "detailed"], + "taskSupport": ["txt2img"], + "sizeGb": 29.43, + "runtimeFootprintGb": 24.0, + "runtimeFootprintMpsGb": 32.0, + "runtimeFootprintCpuGb": 36.0, + "coreWeightsGb": 29.43, + "repoSizeGb": 29.47, + "recommendedResolution": "1024x1024", + "note": "Tracked current text-to-image DiT release from Baidu.", + "gated": False, + "pipelineTag": "text-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-04", + }, + { + "repo": "baidu/ERNIE-Image-Turbo", + "name": "ERNIE-Image-Turbo", + "provider": "Baidu", + "styleTags": ["general", "fast"], + "taskSupport": ["txt2img"], + "sizeGb": 29.43, + "runtimeFootprintGb": 24.0, + "runtimeFootprintMpsGb": 32.0, + "runtimeFootprintCpuGb": 36.0, + "coreWeightsGb": 29.43, + "repoSizeGb": 29.47, + "recommendedResolution": "1024x1024", + "note": "Tracked faster ERNIE-Image lane for current local image generation discovery.", + "gated": False, + "pipelineTag": "text-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-04", + }, + { + "repo": "NucleusAI/Nucleus-Image", + "name": "Nucleus-Image", + "provider": "NucleusAI", + "styleTags": ["general", "detailed"], + "taskSupport": ["txt2img"], + "sizeGb": 48.09, + "runtimeFootprintGb": 48.0, + 
"runtimeFootprintMpsGb": 55.0, + "runtimeFootprintCpuGb": 60.0, + "coreWeightsGb": 48.09, + "repoSizeGb": 48.11, + "recommendedResolution": "1024x1024", + "note": "Tracked current diffusers-compatible text-to-image release.", + "gated": False, + "pipelineTag": "text-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-04", + }, + { + "repo": "black-forest-labs/FLUX.2-dev", + "name": "FLUX.2 Dev", + "provider": "Black Forest Labs", + "styleTags": ["general", "detailed", "flux"], + "taskSupport": ["txt2img", "img2img"], + "sizeGb": 64.7, + "runtimeFootprintGb": 65.0, + "runtimeFootprintMpsGb": 78.0, + "runtimeFootprintCpuGb": 90.0, + "recommendedResolution": "1024x1024", + "note": "Tracked FLUX.2 generation-and-editing release.", + "gated": True, + "pipelineTag": "image-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-02", + }, + { + "repo": "black-forest-labs/FLUX.2-klein-9B", + "name": "FLUX.2 Klein 9B", + "provider": "Black Forest Labs", + "styleTags": ["general", "flux", "fast"], + "taskSupport": ["txt2img", "img2img"], + "sizeGb": 49.23, + "runtimeFootprintGb": 49.0, + "runtimeFootprintMpsGb": 55.0, + "runtimeFootprintCpuGb": 64.0, + "coreWeightsGb": 49.23, + "repoSizeGb": 49.26, + "recommendedResolution": "1024x1024", + "note": "Tracked smaller FLUX.2 lane.", + "gated": False, + "pipelineTag": "image-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-02", + }, + { + "repo": "Tongyi-MAI/Z-Image-Turbo", + "name": "Z-Image-Turbo", + "provider": "Tongyi-MAI", + "styleTags": ["general", "fast"], + "taskSupport": ["txt2img"], + "sizeGb": 30.58, + "runtimeFootprintGb": 16.0, + "runtimeFootprintMpsGb": 20.0, + "runtimeFootprintCpuGb": 24.0, + "coreWeightsGb": 30.58, + "repoSizeGb": 30.64, + "recommendedResolution": "1024x1024", + "note": "Tracked current Z-Image turbo text-to-image release.", + "gated": False, + "pipelineTag": "text-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-01", + }, + { + "repo": "Tongyi-MAI/Z-Image", + "name": "Z-Image", + "provider": "Tongyi-MAI", + "styleTags": ["general", "detailed"], + "taskSupport": ["txt2img"], + "sizeGb": 19.11, + "runtimeFootprintGb": 22.0, + "runtimeFootprintMpsGb": 24.0, + "runtimeFootprintCpuGb": 30.0, + "coreWeightsGb": 19.11, + "repoSizeGb": 19.14, + "recommendedResolution": "1024x1024", + "note": "Tracked current Z-Image text-to-image release.", + "gated": False, + "pipelineTag": "text-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2026-01", + }, + { + "repo": "Qwen/Qwen-Image-Edit-2511", + "name": "Qwen-Image-Edit-2511", + "provider": "Qwen", + "styleTags": ["edit", "qwenimage", "general"], + "taskSupport": ["img2img"], + "sizeGb": 57.7, + "runtimeFootprintGb": 58.0, + "runtimeFootprintMpsGb": 72.0, + "runtimeFootprintCpuGb": 72.0, + "recommendedResolution": "1024x1024", + "note": "Tracked newer Qwen image editing release with improved consistency.", + "gated": False, + "pipelineTag": "image-to-image", + "updatedLabel": "Tracked latest", + "releaseDate": "2025-12", + }, { "repo": "Qwen/Qwen-Image", "name": "Qwen-Image", @@ -278,6 +426,9 @@ "styleTags": ["general", "detailed", "qwenimage"], "taskSupport": ["txt2img"], "sizeGb": 57.7, + "runtimeFootprintGb": 58.0, + "runtimeFootprintMpsGb": 72.0, + "runtimeFootprintCpuGb": 72.0, "recommendedResolution": "1024x1024", "note": "Tracked diffusers-native Qwen image generation family.", "gated": False, @@ -292,6 +443,9 @@ "styleTags": ["edit", "qwenimage", "general"], "taskSupport": ["img2img"], 
"sizeGb": 57.7, + "runtimeFootprintGb": 58.0, + "runtimeFootprintMpsGb": 72.0, + "runtimeFootprintCpuGb": 72.0, "recommendedResolution": "1024x1024", "note": "Tracked Qwen edit lane so Image Discover can surface newer editing-capable models too.", "gated": False, @@ -306,6 +460,9 @@ "styleTags": ["hidream", "detailed", "quality"], "taskSupport": ["txt2img"], "sizeGb": 47.2, + "runtimeFootprintGb": 58.0, + "runtimeFootprintMpsGb": 62.0, + "runtimeFootprintCpuGb": 70.0, "recommendedResolution": "1024x1024", "note": "Tracked larger open-image generation lane from the HiDream family.", "gated": False, @@ -320,6 +477,9 @@ "styleTags": ["general", "edit", "detailed"], "taskSupport": ["txt2img", "img2img"], "sizeGb": 35.8, + "runtimeFootprintGb": 40.0, + "runtimeFootprintMpsGb": 45.0, + "runtimeFootprintCpuGb": 52.0, "recommendedResolution": "1024x1024", "note": "Tracked unified generation-and-editing lane from the GLM image family.", "gated": False, @@ -333,6 +493,9 @@ "styleTags": ["sana", "fast", "small"], "taskSupport": ["txt2img"], "sizeGb": 7.7, + "runtimeFootprintGb": 8.0, + "runtimeFootprintMpsGb": 10.0, + "runtimeFootprintCpuGb": 12.0, "recommendedResolution": "1024x1024", "note": "Tracked smaller Sana Sprint lane for faster local image generation.", "gated": False, @@ -347,6 +510,9 @@ "styleTags": ["sana", "fast", "detailed"], "taskSupport": ["txt2img"], "sizeGb": 9.74, + "runtimeFootprintGb": 10.0, + "runtimeFootprintMpsGb": 12.0, + "runtimeFootprintCpuGb": 15.0, "recommendedResolution": "1024x1024", "note": "Tracked larger Sana Sprint lane with a better quality-to-speed balance.", "gated": False, diff --git a/backend_service/catalog/video_models.py b/backend_service/catalog/video_models.py index e3509bc..9fd6773 100644 --- a/backend_service/catalog/video_models.py +++ b/backend_service/catalog/video_models.py @@ -35,6 +35,7 @@ "styleTags": ["general", "fast", "motion"], "taskSupport": ["txt2video"], "sizeGb": 2.0, + "runtimeFootprintGb": 10.0, "recommendedResolution": "768x512", "defaultDurationSeconds": 4.0, "note": "Small, fast, Apache 2.0 — best starter pick for a local video runtime.", @@ -55,9 +56,10 @@ "styleTags": ["general", "fast", "motion", "gguf"], "taskSupport": ["txt2video"], "sizeGb": 1.4, + "runtimeFootprintGb": 10.0, "recommendedResolution": "768x512", "defaultDurationSeconds": 4.0, - "note": "GGUF Q4_K_M — runs on 6-8 GB VRAM / Apple Silicon at near-native quality.", + "note": "GGUF Q4_K_M — quantizes the transformer, but the LTX pipeline still needs roughly a 10 GB runtime envelope for text encoder/VAE/buffers.", "estimatedGenerationSeconds": 50.0, "availableLocally": False, "releaseDate": "2024-12", @@ -75,6 +77,7 @@ "styleTags": ["general", "motion", "quality", "gguf"], "taskSupport": ["txt2video"], "sizeGb": 1.7, + "runtimeFootprintGb": 10.0, "recommendedResolution": "768x512", "defaultDurationSeconds": 4.0, "note": "GGUF Q6_K — mid-point between Q4 footprint and Q8 fidelity.", @@ -95,6 +98,7 @@ "styleTags": ["general", "motion", "quality", "gguf"], "taskSupport": ["txt2video"], "sizeGb": 2.2, + "runtimeFootprintGb": 10.0, "recommendedResolution": "768x512", "defaultDurationSeconds": 4.0, "note": "GGUF Q8_0 — near-bf16 quality at roughly half the memory.", @@ -129,6 +133,7 @@ "styleTags": ["general", "fast", "motion", "mlx"], "taskSupport": ["txt2video"], "sizeGb": 19.0, + "runtimeFootprintGb": 27.0, "recommendedResolution": "768x512", "defaultDurationSeconds": 4.0, "note": "Distilled LTX-2 — fastest MLX path for previews. 
Use the dev variant for final fidelity.", @@ -147,6 +152,7 @@ "styleTags": ["general", "quality", "motion", "mlx"], "taskSupport": ["txt2video"], "sizeGb": 19.0, + "runtimeFootprintGb": 27.0, "recommendedResolution": "768x512", "defaultDurationSeconds": 4.0, "note": "Full LTX-2 dev weights — higher fidelity, longer sampling than distilled.", @@ -160,11 +166,13 @@ "name": "LTX-2.3 · distilled (MLX)", "provider": "Lightricks · prince-canuma", "repo": "prince-canuma/LTX-2.3-distilled", + "textEncoderRepo": "prince-canuma/LTX-2-distilled", "link": "https://huggingface.co/prince-canuma/LTX-2.3-distilled", "runtime": "mlx-video (MLX native)", "styleTags": ["general", "fast", "motion", "mlx"], "taskSupport": ["txt2video"], "sizeGb": 19.0, + "runtimeFootprintGb": 27.0, "recommendedResolution": "768x512", "defaultDurationSeconds": 4.0, "note": "LTX-2.3 distilled — refreshed fast preview path with sharper texture detail vs LTX-2. Use the dev variant for final fidelity.", @@ -178,11 +186,13 @@ "name": "LTX-2.3 · dev (MLX)", "provider": "Lightricks · prince-canuma", "repo": "prince-canuma/LTX-2.3-dev", + "textEncoderRepo": "prince-canuma/LTX-2-distilled", "link": "https://huggingface.co/prince-canuma/LTX-2.3-dev", "runtime": "mlx-video (MLX native)", "styleTags": ["general", "quality", "motion", "mlx"], "taskSupport": ["txt2video"], "sizeGb": 19.0, + "runtimeFootprintGb": 27.0, "recommendedResolution": "768x512", "defaultDurationSeconds": 4.0, "note": "LTX-2.3 dev — quality tier; full sampler steps for best output. Apple Silicon native via MLX. Install mlx-video from Setup → GPU runtime bundle to enable.", @@ -222,6 +232,7 @@ # Resident peak ~14 GB during text encoding (UMT5-XXL bf16); # drops to ~4 GB during diffusion when encoder is freed. "runtimeFootprintGb": 14.0, + "runtimeFootprintMpsGb": 23.0, "recommendedResolution": "832x480", "defaultDurationSeconds": 4.0, "note": "1.3B transformer + UMT5 text encoder. ~16GB on disk. Best starter pick for trying local video end-to-end on modest hardware.", @@ -267,9 +278,10 @@ # ~0.9 GB GGUF transformer + ~14 GB shared UMT5-XXL/VAE base. "sizeGb": 14.9, "runtimeFootprintGb": 12.5, # Q4_K_M trans (~0.9 GB) + UMT5 (~11 GB) + "runtimeFootprintMpsGb": 21.5, "recommendedResolution": "832x480", "defaultDurationSeconds": 4.0, - "note": "Q4_K_M — smallest quantized 1.3B; runs in <8 GB unified memory once base is cached.", + "note": "Q4_K_M — smallest quantized 1.3B. 
The transformer is tiny, but UMT5/VAE keep the full runtime envelope in the 12-23 GB range depending on device/offload.", "estimatedGenerationSeconds": 70.0, "availableLocally": False, "releaseDate": "2025-03", @@ -287,9 +299,11 @@ "styleTags": ["general", "fast", "small", "gguf"], "taskSupport": ["txt2video"], "sizeGb": 15.2, + "runtimeFootprintGb": 13.2, + "runtimeFootprintMpsGb": 22.0, "recommendedResolution": "832x480", "defaultDurationSeconds": 4.0, - "note": "Q6_K — mid-point between Q4 footprint and Q8 fidelity.", + "note": "Q6_K — mid-point between Q4 footprint and Q8 fidelity; still carries the shared UMT5/VAE runtime overhead.", "estimatedGenerationSeconds": 68.0, "availableLocally": False, "releaseDate": "2025-03", @@ -307,9 +321,11 @@ "styleTags": ["general", "quality", "small", "gguf"], "taskSupport": ["txt2video"], "sizeGb": 15.5, + "runtimeFootprintGb": 13.8, + "runtimeFootprintMpsGb": 22.5, "recommendedResolution": "832x480", "defaultDurationSeconds": 4.0, - "note": "Q8_0 — near-bf16 quality at roughly half the transformer footprint.", + "note": "Q8_0 — near-bf16 quality at roughly half the transformer footprint; shared UMT5/VAE still dominate peak RAM/VRAM.", "estimatedGenerationSeconds": 65.0, "availableLocally": False, "releaseDate": "2025-03", @@ -329,6 +345,8 @@ # ~7 GB GGUF transformer + ~14 GB shared UMT5-XXL/VAE — fits # comfortably on a 24 GB RTX 4090 with VAE headroom. "sizeGb": 21.0, + "runtimeFootprintGb": 18.0, + "runtimeFootprintMpsGb": 27.0, "recommendedResolution": "832x480", "defaultDurationSeconds": 5.0, "note": "Q4_K_M — unlocks Wan 2.1 14B on 24 GB VRAM (RTX 4090) without bnb.", @@ -349,6 +367,8 @@ "styleTags": ["general", "quality", "motion", "gguf"], "taskSupport": ["txt2video"], "sizeGb": 24.0, + "runtimeFootprintGb": 21.0, + "runtimeFootprintMpsGb": 30.0, "recommendedResolution": "832x480", "defaultDurationSeconds": 5.0, "note": "Q6_K — mid-point between Q4 footprint and Q8 fidelity.", @@ -369,6 +389,8 @@ "styleTags": ["general", "quality", "motion", "gguf"], "taskSupport": ["txt2video"], "sizeGb": 28.0, + "runtimeFootprintGb": 25.0, + "runtimeFootprintMpsGb": 34.0, "recommendedResolution": "832x480", "defaultDurationSeconds": 5.0, "note": "Q8_0 — near-bf16 quality at roughly half the transformer footprint.", @@ -415,6 +437,7 @@ # over-estimates resident because the repo carries duplicate # sharded safetensors + tokenizer caches. "runtimeFootprintGb": 22.0, + "runtimeFootprintMpsGb": 24.0, "recommendedResolution": "832x480", "defaultDurationSeconds": 5.0, "note": "Best Wan 2.2 pick for consumer hardware. 24 GB on disk, runs on a 24 GB GPU or a 32 GB+ Mac.", @@ -440,6 +463,7 @@ "sizeGb": 17.5, # GGUF Q4_K_M trans (~3.5 GB) + UMT5-XXL during encode (~11 GB). "runtimeFootprintGb": 14.5, + "runtimeFootprintMpsGb": 22.0, "recommendedResolution": "832x480", "defaultDurationSeconds": 5.0, "note": "Q4_K_M — smallest Wan 2.2 that still generates usable quality. 
Best fit for 16 GB unified memory.", @@ -461,6 +485,7 @@ "taskSupport": ["txt2video"], "sizeGb": 18.2, "runtimeFootprintGb": 16.5, # Q6_K trans ~5 GB + UMT5 ~11 GB + "runtimeFootprintMpsGb": 24.0, "recommendedResolution": "832x480", "defaultDurationSeconds": 5.0, "note": "Q6_K — mid-point between Q4 footprint and Q8 fidelity.", @@ -482,6 +507,7 @@ "taskSupport": ["txt2video"], "sizeGb": 19.0, "runtimeFootprintGb": 18.0, # Q8 trans ~7 GB + UMT5 ~11 GB + "runtimeFootprintMpsGb": 26.0, "recommendedResolution": "832x480", "defaultDurationSeconds": 5.0, "note": "Q8_0 — near-bf16 quality at roughly half the transformer footprint.", @@ -518,6 +544,7 @@ # bogus "needs 176 GB" warning, but the note flags that the # offload mode is required. "runtimeFootprintGb": 30.0, + "runtimeFootprintMpsGb": 36.0, "recommendedResolution": "832x480", "defaultDurationSeconds": 5.0, "note": ( @@ -591,9 +618,10 @@ "styleTags": ["general", "motion", "balanced"], "taskSupport": ["txt2video"], "sizeGb": 10.0, + "runtimeFootprintGb": 22.0, "recommendedResolution": "848x480", "defaultDurationSeconds": 5.4, - "note": "Apache 2.0, balanced footprint, strong motion quality.", + "note": "Apache 2.0, balanced footprint, strong motion quality. Diffusers' bf16/offload path lands around a 22 GB runtime envelope.", "estimatedGenerationSeconds": 150.0, "availableLocally": False, "releaseDate": "2024-10", @@ -629,9 +657,10 @@ # CPU-offload tricks. Smaller than Wan 2.1 1.3B because there's # no UMT5-XXL — just the standard T5. "sizeGb": 9.0, + "runtimeFootprintGb": 19.0, "recommendedResolution": "720x480", "defaultDurationSeconds": 6.0, - "note": "Smallest CogVideoX. Apache 2.0 weights, ~9 GB on disk, runs on consumer GPUs.", + "note": "Smallest CogVideoX. Apache 2.0 weights, ~9 GB on disk; runtime peak is closer to 19 GB without the most aggressive offload/tiling.", "estimatedGenerationSeconds": 90.0, "availableLocally": False, "releaseDate": "2024-08", @@ -650,9 +679,10 @@ # same envelope as Wan 2.2 — needs 24 GB VRAM or 32 GB+ # unified memory. "sizeGb": 18.0, + "runtimeFootprintGb": 33.0, "recommendedResolution": "720x480", "defaultDurationSeconds": 6.0, - "note": "Quality tier. ~18 GB on disk. Same CogVideoXPipeline class as the 2B.", + "note": "Quality tier. 
~18 GB on disk; budget for a 32 GB-class runtime envelope unless aggressive offload is enabled.", "estimatedGenerationSeconds": 200.0, "availableLocally": False, "releaseDate": "2024-08", diff --git a/backend_service/helpers/discovery.py b/backend_service/helpers/discovery.py index 49519cf..676e5c1 100644 --- a/backend_service/helpers/discovery.py +++ b/backend_service/helpers/discovery.py @@ -46,7 +46,7 @@ def _path_size_bytes(path: Path, *, seen: set[tuple[int, int]] | None = None) -> with iterator as entries: for entry in entries: try: - entry_stat = entry.stat(follow_symlinks=False) + entry_stat = entry.stat(follow_symlinks=True) except OSError: continue entry_id = (entry_stat.st_dev, entry_stat.st_ino) @@ -54,7 +54,7 @@ def _path_size_bytes(path: Path, *, seen: set[tuple[int, int]] | None = None) -> continue visited.add(entry_id) try: - is_dir = entry.is_dir(follow_symlinks=False) + is_dir = entry.is_dir(follow_symlinks=True) except OSError: is_dir = False if is_dir: @@ -324,6 +324,7 @@ def _detect_model_quantization(path: Path, fmt: str, *, name_hint: str = "") -> "stable-diffusion", "sdxl", "flux.", "flux1", "flux-", "dall-e", "imagen", "kandinsky", "wuerstchen", "diffusion-pipe", "qwen-image", "qwen/qwen-image", + "sana_sprint", "sana-sprint", "sana sprint", "sana_1600m", "sana-1600m", ) @@ -348,6 +349,7 @@ def _detect_model_quantization(path: Path, fmt: str, *, name_hint: str = "") -> "mochi-1", "cogvideo", "ltx-video", + "ltx-2", "zeroscope", "animatediff", ) diff --git a/backend_service/helpers/huggingface.py b/backend_service/helpers/huggingface.py index b4c192b..8c379ce 100644 --- a/backend_service/helpers/huggingface.py +++ b/backend_service/helpers/huggingface.py @@ -7,6 +7,7 @@ import urllib.error import urllib.parse import urllib.request +from hashlib import sha256 from datetime import datetime, timezone from pathlib import Path from typing import Any @@ -17,7 +18,7 @@ _HF_REPO_PATTERN = re.compile(r"^[a-zA-Z0-9_.\-]+/[a-zA-Z0-9_.\-]+$") -_HUB_FILE_CACHE: dict[str, dict[str, Any]] = {} +_HUB_FILE_CACHE: dict[tuple[str, str], dict[str, Any]] = {} _DISCOVER_SEARCH_PUNCT_RE = re.compile(r"[^a-z0-9]+") _DISCOVER_SEARCH_ALPHA_NUM_RE = re.compile(r"([a-z])(\d)|(\d)([a-z])") _TEXT_DISCOVER_PIPELINES = { @@ -30,6 +31,21 @@ _HF_QUERY_URL_HOSTS = {"huggingface.co", "www.huggingface.co", "hf.co", "www.hf.co"} +def _hf_token_value() -> str: + return str(os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or "").strip() + + +def _hf_token_cache_key() -> str: + token = _hf_token_value() + if not token: + return "anonymous" + return f"token:{sha256(token.encode('utf-8')).hexdigest()[:16]}" + + +def _clear_huggingface_caches() -> None: + _HUB_FILE_CACHE.clear() + + def _extract_hf_repo_id_from_query(value: str) -> str | None: text = str(value or "").strip() if not text: @@ -122,6 +138,9 @@ def _search_huggingface_hub(query: str, library: list[dict[str, Any]], limit: in }) url = f"https://huggingface.co/api/models?{params}" req = urllib.request.Request(url, headers={"User-Agent": "ChaosEngineAI/0.2.0"}) + token = _hf_token_value() + if token: + req.add_header("Authorization", f"Bearer {token}") with urllib.request.urlopen(req, timeout=8) as resp: data = json.loads(resp.read().decode()) except Exception: @@ -246,11 +265,12 @@ def _hub_repo_files(repo_id: str) -> dict[str, Any]: HUGGING_FACE_HUB_TOKEN for gated repos and degrades to a non-fatal warning on transient upstream 5xx errors. 
""" - cached = _HUB_FILE_CACHE.get(repo_id) + cache_key = (repo_id, _hf_token_cache_key()) + cached = _HUB_FILE_CACHE.get(cache_key) if cached is not None: return cached - token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") + token = _hf_token_value() try: encoded_repo = urllib.parse.quote(repo_id, safe="/") url = f"https://huggingface.co/api/models/{encoded_repo}?blobs=true" @@ -326,7 +346,7 @@ def _hub_repo_files(repo_id: str) -> dict[str, Any]: pipeline_tag=data.get("pipeline_tag"), last_modified=data.get("lastModified"), ) - _HUB_FILE_CACHE[repo_id] = payload + _HUB_FILE_CACHE[cache_key] = payload return payload @@ -509,7 +529,7 @@ def _hf_repo_snapshot_dir(repo_id: str) -> Path | None: def _known_repo_size_gb(repo_id: str) -> float | None: - cached = _HUB_FILE_CACHE.get(repo_id) + cached = _HUB_FILE_CACHE.get((repo_id, _hf_token_cache_key())) if cached is not None: cached_total = cached.get("totalSizeGb") if isinstance(cached_total, (int, float)) and cached_total > 0: diff --git a/backend_service/helpers/images.py b/backend_service/helpers/images.py index b58f064..51fcd7d 100644 --- a/backend_service/helpers/images.py +++ b/backend_service/helpers/images.py @@ -21,15 +21,17 @@ _format_release_label, _hf_number_label, _hf_repo_snapshot_dir, + _hf_token_cache_key, + _hf_token_value, _parse_iso_datetime, ) from backend_service.helpers.discovery import _candidate_model_dirs, _path_size_bytes from backend_service.image_runtime import validate_local_diffusers_snapshot -_IMAGE_DISCOVER_METADATA_CACHE: dict[str, tuple[float, dict[str, Any]]] = {} +_IMAGE_DISCOVER_METADATA_CACHE: dict[tuple[str, str], tuple[float, dict[str, Any]]] = {} _IMAGE_DISCOVER_METADATA_TTL_SECONDS = 6 * 60 * 60 -_LATEST_IMAGE_MODELS_CACHE: tuple[float, list[dict[str, Any]]] | None = None +_LATEST_IMAGE_MODELS_CACHE: tuple[float, str, list[dict[str, Any]]] | None = None _LATEST_IMAGE_MODELS_TTL_SECONDS = 3 * 60 * 60 # Cache keyed by (path, mtime_ns) — we recompute only when the snapshot dir @@ -37,6 +39,47 @@ _SNAPSHOT_SIZE_CACHE: dict[tuple[str, int], int] = {} +def _positive_float(value: Any) -> float | None: + try: + parsed = float(value) + except (TypeError, ValueError): + return None + if parsed > 0: + return parsed + return None + + +def _positive_int(value: Any) -> int | None: + try: + parsed = int(value) + except (TypeError, ValueError): + return None + if parsed > 0: + return parsed + return None + + +def _image_seed_size_metadata(seed: dict[str, Any]) -> tuple[float, float | None, float | None]: + catalog_size_gb = _positive_float(seed.get("sizeGb")) + core_weights_gb = _positive_float(seed.get("coreWeightsGb")) or catalog_size_gb + repo_size_gb = _positive_float(seed.get("repoSizeGb")) + size_gb = core_weights_gb or repo_size_gb or catalog_size_gb or 0.0 + return float(size_gb), core_weights_gb, repo_size_gb + + +def _tracked_seed_for_repo(repo_id: str) -> dict[str, Any] | None: + for seed in LATEST_IMAGE_TRACKED_SEEDS: + if str(seed.get("repo") or "") == repo_id: + return seed + return None + + +def _clear_image_discover_caches() -> None: + global _LATEST_IMAGE_MODELS_CACHE + _IMAGE_DISCOVER_METADATA_CACHE.clear() + _LATEST_IMAGE_MODELS_CACHE = None + + def _snapshot_on_disk_bytes(snapshot_dir: Path | None) -> int | None: """Walk the HF snapshot dir and return its true on-disk byte size. 
@@ -166,12 +209,19 @@ def _find_image_variant(model_id: str) -> dict[str, Any] | None: for seed in LATEST_IMAGE_TRACKED_SEEDS: repo = str(seed.get("repo") or "") if repo == model_id: + size_gb, core_weights_gb, repo_size_gb = _image_seed_size_metadata(seed) return { "id": repo, "repo": repo, "name": seed.get("name") or repo.split("/", 1)[-1], "provider": seed.get("provider") or "Community", - "sizeGb": seed.get("sizeGb") or 0, + "sizeGb": size_gb, + "runtimeFootprintGb": seed.get("runtimeFootprintGb"), + "runtimeFootprintMpsGb": seed.get("runtimeFootprintMpsGb"), + "runtimeFootprintCudaGb": seed.get("runtimeFootprintCudaGb"), + "runtimeFootprintCpuGb": seed.get("runtimeFootprintCpuGb"), + "coreWeightsGb": core_weights_gb, + "repoSizeGb": repo_size_gb, "styleTags": list(seed.get("styleTags") or []), "taskSupport": list(seed.get("taskSupport") or ["txt2img"]), "recommendedResolution": seed.get("recommendedResolution") or "1024x1024", @@ -188,12 +238,19 @@ def _find_image_variant_by_repo(repo: str) -> dict[str, Any] | None: for seed in LATEST_IMAGE_TRACKED_SEEDS: seed_repo = str(seed.get("repo") or "") if seed_repo == repo: + size_gb, core_weights_gb, repo_size_gb = _image_seed_size_metadata(seed) return { "id": seed_repo, "repo": seed_repo, "name": seed.get("name") or seed_repo.split("/", 1)[-1], "provider": seed.get("provider") or "Community", - "sizeGb": seed.get("sizeGb") or 0, + "sizeGb": size_gb, + "runtimeFootprintGb": seed.get("runtimeFootprintGb"), + "runtimeFootprintMpsGb": seed.get("runtimeFootprintMpsGb"), + "runtimeFootprintCudaGb": seed.get("runtimeFootprintCudaGb"), + "runtimeFootprintCpuGb": seed.get("runtimeFootprintCpuGb"), + "coreWeightsGb": core_weights_gb, + "repoSizeGb": repo_size_gb, "styleTags": list(seed.get("styleTags") or []), "taskSupport": list(seed.get("taskSupport") or ["txt2img"]), "recommendedResolution": seed.get("recommendedResolution") or "1024x1024", @@ -203,13 +260,14 @@ def _find_image_variant_by_repo(repo: str) -> dict[str, Any] | None: def _image_repo_live_metadata(repo_id: str) -> dict[str, Any]: now = time.time() - cached = _IMAGE_DISCOVER_METADATA_CACHE.get(repo_id) + cache_key = (repo_id, _hf_token_cache_key()) + cached = _IMAGE_DISCOVER_METADATA_CACHE.get(cache_key) if cached is not None: cached_at, payload = cached if (now - cached_at) < _IMAGE_DISCOVER_METADATA_TTL_SECONDS: return payload - token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") + token = _hf_token_value() payload: dict[str, Any] try: encoded_repo = urllib.parse.quote(repo_id, safe="/") @@ -222,6 +280,7 @@ def _image_repo_live_metadata(repo_id: str) -> dict[str, Any]: total_bytes = 0 weight_bytes = 0 + used_storage_bytes = _positive_int(data.get("usedStorage")) for sibling in data.get("siblings") or []: if not isinstance(sibling, dict): continue @@ -229,14 +288,12 @@ def _image_repo_live_metadata(repo_id: str) -> dict[str, Any]: if not path: continue lfs = sibling.get("lfs") if isinstance(sibling.get("lfs"), dict) else {} - size_bytes = sibling.get("size") or lfs.get("size") or 0 - try: - size_int = int(size_bytes) - except (TypeError, ValueError): - size_int = 0 + size_int = _positive_int(sibling.get("size")) or _positive_int(lfs.get("size")) or 0 total_bytes += size_int if _classify_hub_file(path) == "weight": weight_bytes += size_int + if total_bytes <= 0 and used_storage_bytes is not None: + total_bytes = used_storage_bytes card = data.get("cardData") or {} license_value = str(card.get("license") or "").strip() or None if isinstance(card, dict) else 
None @@ -276,7 +333,7 @@ def _image_repo_live_metadata(repo_id: str) -> dict[str, Any]: "metadataWarning": "Live Hugging Face metadata is temporarily unavailable. Showing curated defaults.", } - _IMAGE_DISCOVER_METADATA_CACHE[repo_id] = (now, payload) + _IMAGE_DISCOVER_METADATA_CACHE[cache_key] = (now, payload) return payload @@ -365,6 +422,7 @@ def _tracked_latest_seed_payloads(library: list[dict[str, Any]]) -> list[dict[st release_date = str(seed.get("releaseDate") or "").strip() or None snapshot_dir = _hf_repo_snapshot_dir(repo_id) on_disk_bytes = _snapshot_on_disk_bytes(snapshot_dir) + size_gb, core_weights_gb, repo_size_gb = _image_seed_size_metadata(seed) payloads.append( { "id": repo_id, @@ -377,7 +435,11 @@ def _tracked_latest_seed_payloads(library: list[dict[str, Any]]) -> list[dict[st "runtime": "Tracked diffusers candidate", "styleTags": list(seed.get("styleTags") or []), "taskSupport": list(seed.get("taskSupport") or ["txt2img"]), - "sizeGb": float(seed.get("sizeGb") or 0.0), + "sizeGb": size_gb, + "runtimeFootprintGb": seed.get("runtimeFootprintGb"), + "runtimeFootprintMpsGb": seed.get("runtimeFootprintMpsGb"), + "runtimeFootprintCudaGb": seed.get("runtimeFootprintCudaGb"), + "runtimeFootprintCpuGb": seed.get("runtimeFootprintCpuGb"), "recommendedResolution": str(seed.get("recommendedResolution") or "Unknown"), "note": str( seed.get("note") @@ -402,9 +464,9 @@ def _tracked_latest_seed_payloads(library: list[dict[str, Any]]) -> list[dict[st "gated": seed.get("gated"), "pipelineTag": seed.get("pipelineTag"), "repoSizeBytes": None, - "repoSizeGb": None, + "repoSizeGb": repo_size_gb, "coreWeightsBytes": None, - "coreWeightsGb": None, + "coreWeightsGb": core_weights_gb, "metadataWarning": "Showing ChaosEngineAI tracked latest defaults until live Hugging Face metadata is available.", "source": "latest", } @@ -419,11 +481,15 @@ def _is_latest_image_candidate(model: dict[str, Any], curated_repos: set[str]) - lowered = model_id.lower() excluded_fragments = ( "-lora", + "_lora", + "lora-", "controlnet", "ip-adapter", + "adapter", "tensorrt", "_amdgpu", "onnx", + "embedding", "instruct-pix2pix", ) if any(fragment in lowered for fragment in excluded_fragments): @@ -431,19 +497,44 @@ def _is_latest_image_candidate(model: dict[str, Any], curated_repos: set[str]) - tags = {str(tag).lower() for tag in (model.get("tags") or [])} pipeline_tag = str(model.get("pipeline_tag") or "").lower() - allowed_orgs = { + excluded_tags = { + "lora", + "controlnet", + "adapter", + "adapters", + "textual-inversion", + "embedding", + "embeddings", + "onnx", + } + if tags & excluded_tags: + return False + + trusted_providers = { "black-forest-labs", + "baidu", "stabilityai", "qwen", "hidream-ai", "zai-org", + "tongyi-mai", + "nucleusai", "efficient-large-model", "hunyuanvideo-community", "tencent-hunyuan", "thudm", + "diffusers", } provider = model_id.split("/", 1)[0].lower() if "/" in model_id else "" - if provider and provider not in allowed_orgs: + try: + downloads = int(model.get("downloads") or 0) + except (TypeError, ValueError): + downloads = 0 + try: + likes = int(model.get("likes") or 0) + except (TypeError, ValueError): + likes = 0 + if provider and provider not in trusted_providers and downloads < 1000 and likes < 25: return False if "diffusers" not in tags: @@ -467,9 +558,14 @@ def _latest_image_model_payloads(library: list[dict[str, Any]], limit: int = 10) } now = time.time() + token_cache_key = _hf_token_cache_key() cached_entries = _LATEST_IMAGE_MODELS_CACHE - if cached_entries is not None and 
(now - cached_entries[0]) < _LATEST_IMAGE_MODELS_TTL_SECONDS: - latest = cached_entries[1] + if ( + cached_entries is not None + and cached_entries[1] == token_cache_key + and (now - cached_entries[0]) < _LATEST_IMAGE_MODELS_TTL_SECONDS + ): + latest = cached_entries[2] return [ { **entry, @@ -481,18 +577,21 @@ def _latest_image_model_payloads(library: list[dict[str, Any]], limit: int = 10) try: params = urllib.parse.urlencode({ "filter": "diffusers", - "sort": "modified", + "sort": "createdAt", "direction": "-1", - "limit": "48", + "limit": "96", "full": "true", }) url = f"https://huggingface.co/api/models?{params}" req = urllib.request.Request(url, headers={"User-Agent": "ChaosEngineAI/0.2.0"}) + token = _hf_token_value() + if token: + req.add_header("Authorization", f"Bearer {token}") with urllib.request.urlopen(req, timeout=8) as resp: data = json.loads(resp.read().decode()) except Exception: - if cached_entries is not None: - latest = cached_entries[1] + if cached_entries is not None and cached_entries[1] == token_cache_key: + latest = cached_entries[2] return [ { **entry, @@ -502,10 +601,16 @@ def _latest_image_model_payloads(library: list[dict[str, Any]], limit: int = 10) ] return _tracked_latest_seed_payloads(library)[:limit] - candidates: list[dict[str, Any]] = [] + accepted_models: list[dict[str, Any]] = [] for model in data: if not isinstance(model, dict) or not _is_latest_image_candidate(model, curated_repos): continue + accepted_models.append(model) + if len(accepted_models) >= max(limit * 2, limit): + break + + candidates: list[dict[str, Any]] = [] + for model in accepted_models: model_id = str(model.get("id") or "") provider = model_id.split("/", 1)[0] if "/" in model_id else "Community" tags = [str(tag) for tag in (model.get("tags") or [])] @@ -513,6 +618,22 @@ def _latest_image_model_payloads(library: list[dict[str, Any]], limit: int = 10) metadata = _image_repo_live_metadata(model_id) snapshot_dir = _hf_repo_snapshot_dir(model_id) on_disk_bytes = _snapshot_on_disk_bytes(snapshot_dir) + on_disk_gb = _bytes_to_gb(on_disk_bytes) if on_disk_bytes else None + tracked_seed = _tracked_seed_for_repo(model_id) + fallback_size_gb, fallback_core_weights_gb, fallback_repo_size_gb = ( + _image_seed_size_metadata(tracked_seed) + if tracked_seed is not None + else (0.0, None, None) + ) + core_weights_gb = _positive_float(metadata.get("coreWeightsGb")) or fallback_core_weights_gb + repo_size_gb = _positive_float(metadata.get("repoSizeGb")) or fallback_repo_size_gb + size_gb = ( + _positive_float(metadata.get("coreWeightsGb")) + or _positive_float(metadata.get("repoSizeGb")) + or _positive_float(on_disk_gb) + or _positive_float(fallback_size_gb) + or 0.0 + ) candidates.append({ "id": model_id, "familyId": "latest", @@ -524,7 +645,7 @@ def _latest_image_model_payloads(library: list[dict[str, Any]], limit: int = 10) "runtime": "Diffusers candidate", "styleTags": _image_discover_style_tags(tags), "taskSupport": _image_task_support_from_metadata(pipeline_tag, tags), - "sizeGb": float(metadata.get("coreWeightsGb") or metadata.get("repoSizeGb") or 0.0), + "sizeGb": size_gb, "recommendedResolution": _image_recommended_resolution(model_id, pipeline_tag, tags), "note": ( "Latest official diffusers-compatible image model tracked by ChaosEngineAI. 
" @@ -534,7 +655,7 @@ def _latest_image_model_payloads(library: list[dict[str, Any]], limit: int = 10) "hasLocalData": snapshot_dir is not None, "localPath": str(snapshot_dir) if snapshot_dir else None, "onDiskBytes": on_disk_bytes, - "onDiskGb": _bytes_to_gb(on_disk_bytes) if on_disk_bytes else None, + "onDiskGb": on_disk_gb, "estimatedGenerationSeconds": None, "downloads": metadata.get("downloads"), "likes": metadata.get("likes"), @@ -548,9 +669,9 @@ def _latest_image_model_payloads(library: list[dict[str, Any]], limit: int = 10) "gated": bool(metadata.get("gated")) if metadata.get("gated") is not None else None, "pipelineTag": metadata.get("pipelineTag") or pipeline_tag, "repoSizeBytes": metadata.get("repoSizeBytes"), - "repoSizeGb": metadata.get("repoSizeGb"), + "repoSizeGb": repo_size_gb, "coreWeightsBytes": metadata.get("coreWeightsBytes"), - "coreWeightsGb": metadata.get("coreWeightsGb"), + "coreWeightsGb": core_weights_gb, "metadataWarning": metadata.get("metadataWarning"), "source": "latest", }) @@ -572,7 +693,7 @@ def _latest_image_model_payloads(library: list[dict[str, Any]], limit: int = 10) seen_repos.add(repo_id) latest = candidates[:limit] - _LATEST_IMAGE_MODELS_CACHE = (now, latest) + _LATEST_IMAGE_MODELS_CACHE = (now, token_cache_key, latest) return latest @@ -662,7 +783,7 @@ def _image_download_repo_ids() -> set[str]: if cached_entries is not None: repos.update( str(entry.get("repo") or "") - for entry in cached_entries[1] + for entry in cached_entries[2] if str(entry.get("repo") or "") ) return repos diff --git a/backend_service/helpers/persistence.py b/backend_service/helpers/persistence.py index 86322c2..b8f3d75 100644 --- a/backend_service/helpers/persistence.py +++ b/backend_service/helpers/persistence.py @@ -11,7 +11,7 @@ LEGACY_SEEDED_CHAT_IDS = {"ui-direction", "model-shortlist"} LEGACY_SEEDED_BENCHMARK_IDS = {"baseline", "native-34", "native-36", "native-44"} -LIBRARY_CACHE_VERSION = 1 +LIBRARY_CACHE_VERSION = 3 def _default_chat_variant() -> dict[str, Any]: diff --git a/backend_service/helpers/system.py b/backend_service/helpers/system.py index b3a5584..fad84ce 100644 --- a/backend_service/helpers/system.py +++ b/backend_service/helpers/system.py @@ -378,9 +378,23 @@ def _list_llm_processes(limit: int = 12) -> list[dict[str, Any]]: return matches[:limit] -def _build_system_snapshot(app_version: str, app_started_at: float) -> dict[str, Any]: - from backend_service.inference import get_backend_capabilities - native = get_backend_capabilities().to_dict() +def _capabilities_payload(capabilities: Any | None = None) -> dict[str, Any]: + if capabilities is None: + from backend_service.inference import get_backend_capabilities + return get_backend_capabilities().to_dict() + to_dict = getattr(capabilities, "to_dict", None) + if callable(to_dict): + return dict(to_dict()) + return dict(capabilities) + + +def _build_system_snapshot( + app_version: str, + app_started_at: float, + *, + capabilities: Any | None = None, +) -> dict[str, Any]: + native = _capabilities_payload(capabilities) memory = psutil.virtual_memory() try: swap = psutil.swap_memory() diff --git a/backend_service/helpers/video.py b/backend_service/helpers/video.py index d0409ab..d5b6684 100644 --- a/backend_service/helpers/video.py +++ b/backend_service/helpers/video.py @@ -67,8 +67,27 @@ def _video_model_payloads(library: list[dict[str, Any]]) -> list[dict[str, Any]] # Merge live metadata first so curated fields (releaseDate, # familyName) still win when both exist. 
enriched = {**enriched, **live_metadata} - enriched["availableLocally"] = _video_repo_runtime_ready(repo) if repo else False - enriched["hasLocalData"] = enriched["availableLocally"] or _video_repo_has_any_local_data(repo) + validation_error = _video_variant_validation_error(enriched) + local_data_repos = _video_variant_local_data_repos(enriched) + enriched["availableLocally"] = validation_error is None + enriched["hasLocalData"] = ( + enriched["availableLocally"] + or bool(local_data_repos) + ) + primary_local_repo = ( + repo + if repo and repo in local_data_repos + else local_data_repos[0] + if local_data_repos + else None + ) + enriched["localDataRepos"] = local_data_repos + enriched["primaryLocalRepo"] = primary_local_repo + enriched["localStatusReason"] = ( + _video_variant_local_status_reason(enriched, validation_error) + if enriched["hasLocalData"] and validation_error + else None + ) enriched["familyName"] = family["name"] release_date = str(variant.get("releaseDate") or "").strip() or None enriched["releaseDate"] = release_date @@ -81,7 +100,11 @@ def _video_model_payloads(library: list[dict[str, Any]]) -> list[dict[str, Any]] # Absolute path to the HF snapshot, used by the Reveal File button. # Only populated when there is actually something on disk so the # UI can reliably hide the button otherwise. - snapshot_dir = _hf_repo_snapshot_dir(repo) if (enriched["hasLocalData"] and repo) else None + snapshot_dir = ( + _hf_repo_snapshot_dir(primary_local_repo) + if (enriched["hasLocalData"] and primary_local_repo) + else None + ) enriched["localPath"] = str(snapshot_dir) if snapshot_dir else None on_disk_bytes = _snapshot_on_disk_bytes(snapshot_dir) enriched["onDiskBytes"] = on_disk_bytes @@ -104,11 +127,20 @@ def _find_video_variant(model_id: str) -> dict[str, Any] | None: def _find_video_variant_by_repo(repo: str) -> dict[str, Any] | None: for family in VIDEO_MODEL_FAMILIES: for variant in family["variants"]: - if variant["repo"] == repo: + if repo in _video_variant_download_repos(variant): return variant return None +def _video_variant_download_repos(variant: dict[str, Any]) -> list[str]: + repos: list[str] = [] + for key in ("repo", "ggufRepo", "textEncoderRepo"): + repo = str(variant.get(key) or "").strip() + if repo and repo not in repos: + repos.append(repo) + return repos + + def _is_video_repo(repo_id: str) -> bool: return any( str(variant.get("repo") or "") == repo_id @@ -117,19 +149,23 @@ def _is_video_repo(repo_id: str) -> bool: ) +def _is_video_download_repo(repo_id: str) -> bool: + return repo_id in _video_download_repo_ids() + + def _video_repo_runtime_ready(repo_id: str) -> bool: """True if the local snapshot is complete enough to load. - Routes the validator by engine: mlx-video repos ship text_encoder / - tokenizer / transformer / vae folders without ``model_index.json``, - so the diffusers-shape check would always falsely fail for them. - Diffusers repos still go through ``validate_local_diffusers_snapshot``. + Routes the validator by engine: mlx-video repos ship component folders + without ``model_index.json``, so the diffusers-shape check would always + falsely fail for them. Diffusers repos still go through + ``validate_local_diffusers_snapshot``. 
""" snapshot_dir = _hf_repo_snapshot_dir(repo_id) if snapshot_dir is None: return False if _is_mlx_video_routed_repo(repo_id): - return _validate_mlx_video_snapshot(snapshot_dir) is None + return _validate_mlx_video_snapshot(snapshot_dir, repo_id) is None return validate_local_diffusers_snapshot(snapshot_dir, repo_id) is None @@ -155,20 +191,154 @@ def _video_repo_has_any_local_data(repo_id: str) -> bool: return False +def _video_variant_local_data_repos(variant: dict[str, Any]) -> list[str]: + return [ + repo + for repo in _video_variant_download_repos(variant) + if _video_repo_has_any_local_data(repo) + ] + + def _video_variant_available_locally(variant: dict[str, Any]) -> bool: + return _video_variant_validation_error(variant) is None + + +def _video_variant_has_any_local_data(variant: dict[str, Any]) -> bool: + return bool(_video_variant_local_data_repos(variant)) + + +def _video_variant_validation_error(variant: dict[str, Any]) -> str | None: repo = str(variant.get("repo") or "") if not repo: - return False - return _video_repo_runtime_ready(repo) + return "Video model variant is missing its base repo id." + repo_error = _video_download_validation_error(repo) + if repo_error: + return repo_error + text_error = _video_variant_mlx_text_components_validation_error(variant) + if text_error: + return text_error + return _video_variant_gguf_validation_error(variant) + + +def _video_variant_mlx_text_components_validation_error(variant: dict[str, Any]) -> str | None: + repo = str(variant.get("repo") or "") + if not _is_mlx_video_routed_repo(repo): + return None + snapshot_dir = _hf_repo_snapshot_dir(repo) + if snapshot_dir is None: + return None + missing = _missing_mlx_text_components(Path(snapshot_dir)) + if not missing: + return None + + text_encoder_repo = str(variant.get("textEncoderRepo") or "").strip() + if text_encoder_repo and text_encoder_repo != repo: + text_snapshot = _hf_repo_snapshot_dir(text_encoder_repo) + if text_snapshot is not None and not _missing_mlx_text_components(Path(text_snapshot)): + return None + return ( + "The local snapshot is missing shared mlx-video text components: " + f"{', '.join(missing)}. Download the shared text encoder " + f"({text_encoder_repo}) and retry." + ) + + return ( + "The local snapshot is incomplete. Missing mlx-video components: " + f"{', '.join(missing)}. Re-download the model and keep ChaosEngineAI " + "open until the download completes." + ) + + +def _video_variant_missing_text_encoder_repo(variant: dict[str, Any]) -> str | None: + error = _video_variant_mlx_text_components_validation_error(variant) + text_encoder_repo = str(variant.get("textEncoderRepo") or "").strip() + if error and text_encoder_repo: + return text_encoder_repo + return None + + +def _video_variant_local_status_reason( + variant: dict[str, Any], + validation_error: str | None, +) -> str | None: + if not validation_error: + return None + gguf_file = str(variant.get("ggufFile") or "").strip() + gguf_repo = str(variant.get("ggufRepo") or "").strip() + if gguf_file and gguf_repo and "GGUF transformer file is missing" in validation_error: + return f"Base model installed; missing GGUF transformer: {gguf_repo}/{gguf_file}." + + prefix = "The local snapshot is incomplete. Missing mlx-video components: " + if validation_error.startswith(prefix): + missing = validation_error[len(prefix):].split(". Re-download", 1)[0] + return f"Missing MLX components: {missing}." 
+ + shared_prefix = "The local snapshot is missing shared mlx-video text components: " + if validation_error.startswith(shared_prefix): + missing = validation_error[len(shared_prefix):].split(". Download", 1)[0] + text_encoder_repo = str(variant.get("textEncoderRepo") or "").strip() + source = f" from {text_encoder_repo}" if text_encoder_repo else "" + return f"Missing shared MLX text components{source}: {missing}." + + if validation_error.startswith("The selected GGUF transformer resolved to a cache path"): + return f"GGUF transformer cache path is invalid: {gguf_repo}/{gguf_file}." + + return validation_error + + +def _video_variant_gguf_validation_error(variant: dict[str, Any]) -> str | None: + gguf_file = str(variant.get("ggufFile") or "").strip() + if not gguf_file: + return None + gguf_repo = str(variant.get("ggufRepo") or "").strip() + if not gguf_repo: + return ( + f"{variant.get('name') or 'This GGUF video variant'} is missing " + "its GGUF repository metadata." + ) + try: + from huggingface_hub import hf_hub_download # type: ignore + + local_path = hf_hub_download( + repo_id=gguf_repo, + filename=gguf_file, + local_files_only=True, + ) + except Exception: + return ( + "The base diffusers snapshot is installed, but the selected GGUF " + f"transformer file is missing: {gguf_repo}/{gguf_file}. Download " + "the GGUF variant before generating so the app does not fall back " + "to the full BF16 transformer." + ) + if not Path(local_path).exists(): + return ( + "The selected GGUF transformer resolved to a cache path that does " + f"not exist: {gguf_repo}/{gguf_file}. Retry the GGUF download." + ) + return None def _video_download_repo_ids() -> set[str]: - return { + repos = { str(variant.get("repo") or "") for family in VIDEO_MODEL_FAMILIES for variant in family["variants"] if str(variant.get("repo") or "") } + repos.update( + str(variant.get("ggufRepo") or "") + for family in VIDEO_MODEL_FAMILIES + for variant in family["variants"] + if str(variant.get("ggufRepo") or "") + ) + repos.update( + str(variant.get("textEncoderRepo") or "") + for family in VIDEO_MODEL_FAMILIES + for variant in family["variants"] + if str(variant.get("textEncoderRepo") or "") + ) + return repos # Diffusers pipelines only need the standard per-component folders @@ -202,13 +372,23 @@ def _video_download_repo_ids() -> set[str]: _VIDEO_MLX_ALLOW_PATTERNS: list[str] = [ "text_encoder/**", "tokenizer/**", + "text_projections/**", + "audio_vae/**", "transformer/**", "vae/**", + "vocoder/**", "*spatial-upscaler*.safetensors", "*.md", "LICENSE*", ] +_VIDEO_MLX_TEXT_ENCODER_ALLOW_PATTERNS: list[str] = [ + "text_encoder/**", + "tokenizer/**", + "*.md", + "LICENSE*", +] + def _video_repo_allow_patterns(repo_id: str) -> list[str] | None: """Patterns to pass to ``snapshot_download`` for a video repo. @@ -236,11 +416,10 @@ def _video_download_validation_error(repo_id: str) -> str | None: "Retry the download and make sure the backend can access Hugging Face." ) # mlx-video routed repos (e.g. ``prince-canuma/LTX-2-*``) ship MLX - # layout — text_encoder / tokenizer / transformer / vae folders - # without ``model_index.json``. Don't apply the diffusers-shape - # validator to them; check for the MLX component folders instead. + # layouts without ``model_index.json``. Don't apply the diffusers-shape + # validator to them; check for the expected MLX component folders instead. 
if _is_mlx_video_routed_repo(repo_id): - return _validate_mlx_video_snapshot(snapshot_dir) + return _validate_mlx_video_snapshot(snapshot_dir, repo_id) return validate_local_diffusers_snapshot(snapshot_dir, repo_id) @@ -261,16 +440,54 @@ def _is_mlx_video_routed_repo(repo_id: str) -> bool: # diffusers layout — no model_index.json. Lifted from the ``prince-canuma/ # LTX-2-distilled`` repo tree as the canonical shape; bump as new mlx-video # families with different layouts come online. -_MLX_VIDEO_REQUIRED_COMPONENTS: tuple[str, ...] = ( +_MLX_VIDEO_LTX2_REQUIRED_COMPONENTS: tuple[str, ...] = ( "text_encoder", "tokenizer", + "text_projections", + "transformer", + "vae", +) + +_MLX_VIDEO_LTX23_REQUIRED_COMPONENTS: tuple[str, ...] = ( + "audio_vae", + "text_projections", "transformer", "vae", + "vocoder", ) -def _validate_mlx_video_snapshot(snapshot_dir: str) -> str | None: - """Return ``None`` if the snapshot has the four MLX component folders. +def _mlx_video_required_components(repo_id: str | None = None) -> tuple[str, ...]: + repo_key = str(repo_id or "").lower() + if "ltx-2.3" in repo_key: + return _MLX_VIDEO_LTX23_REQUIRED_COMPONENTS + return _MLX_VIDEO_LTX2_REQUIRED_COMPONENTS + + +def _missing_mlx_text_components(root: Path) -> list[str]: + missing: list[str] = [] + checks = { + "text_encoder": ( + root / "text_encoder" / "config.json", + root / "text_encoder" / "model.safetensors.index.json", + ), + "tokenizer": ( + root / "tokenizer" / "tokenizer.json", + root / "tokenizer" / "tokenizer.model", + ), + } + for component, required_paths in checks.items(): + component_dir = root / component + if not component_dir.is_dir(): + missing.append(component) + continue + if not all(path.exists() for path in required_paths): + missing.append(component) + return missing + + +def _validate_mlx_video_snapshot(snapshot_dir: str, repo_id: str | None = None) -> str | None: + """Return ``None`` if the snapshot has the required MLX component folders. Mirrors the contract of ``validate_local_diffusers_snapshot`` so the callers can swap one for the other without restructuring the result @@ -284,7 +501,7 @@ def _validate_mlx_video_snapshot(snapshot_dir: str) -> str | None: "Re-download the model." ) missing: list[str] = [] - for component in _MLX_VIDEO_REQUIRED_COMPONENTS: + for component in _mlx_video_required_components(repo_id): component_dir = root / component if not component_dir.is_dir(): missing.append(component) diff --git a/backend_service/image_runtime.py b/backend_service/image_runtime.py index c4f987b..5fd46ea 100644 --- a/backend_service/image_runtime.py +++ b/backend_service/image_runtime.py @@ -26,7 +26,6 @@ PHASE_LOADING, PHASE_SAVING, ) -from cache_compression import apply_diffusion_cache_strategy WORKSPACE_ROOT = Path(__file__).resolve().parents[1] @@ -660,6 +659,8 @@ def generate(self, config: ImageGenerationConfig) -> list[GeneratedImage]: # for this pipeline yet we swallow NotImplementedError and run # the stock pipeline — the UI surfaces the "Scaffold" badge so # users know why speedup didn't appear. 
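# (A hedged sketch of the contract assumed here, with the call shape
# illustrative rather than authoritative: the helper either patches the
# pipeline and returns a note for the UI, or raises NotImplementedError
# for pipelines it cannot wrap, roughly
#
#     try:
#         note = apply_diffusion_cache_strategy(pipeline, strategy_id=sid)
#     except NotImplementedError:
#         note = None   # stock pipeline; the Scaffold badge explains why
#
# and deferring the import below keeps ``cache_compression`` optional at
# module-import time, matching the top-level import removed above.)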
+ from cache_compression import apply_diffusion_cache_strategy + cache_note = apply_diffusion_cache_strategy( pipeline, strategy_id=config.cacheStrategy, diff --git a/backend_service/inference.py b/backend_service/inference.py index 423c0fe..ef8c321 100644 --- a/backend_service/inference.py +++ b/backend_service/inference.py @@ -593,6 +593,7 @@ class BackendCapabilities: converterAvailable: bool = False vllmAvailable: bool = False vllmVersion: str | None = None + probing: bool = False def to_dict(self) -> dict[str, Any]: return { @@ -610,6 +611,7 @@ def to_dict(self) -> dict[str, Any]: "converterAvailable": self.converterAvailable, "vllmAvailable": self.vllmAvailable, "vllmVersion": self.vllmVersion, + "probing": self.probing, } @@ -617,6 +619,35 @@ def to_dict(self) -> dict[str, Any]: _capability_lock = RLock() +def _initial_backend_capabilities() -> BackendCapabilities: + """Cheap capability placeholder used while the real probe runs. + + The full probe imports/spawns MLX and checks vLLM, which can add seconds + to cold start. These path checks are safe enough for initial UI rendering; + load_model() still refreshes capabilities synchronously before selecting + an engine. + """ + python_executable = _resolve_mlx_python() + llama_server_path = _resolve_llama_server() + llama_server_turbo_path = _resolve_llama_server_turbo() + llama_cli_path = _resolve_llama_cli() + return BackendCapabilities( + pythonExecutable=python_executable, + mlxAvailable=False, + mlxLmAvailable=False, + mlxUsable=False, + mlxMessage="Native backend detection is still running.", + ggufAvailable=bool(llama_server_path) or bool(llama_server_turbo_path), + llamaCliPath=llama_cli_path, + llamaServerPath=llama_server_path, + llamaServerTurboPath=llama_server_turbo_path, + converterAvailable=False, + vllmAvailable=False, + vllmVersion=None, + probing=True, + ) + + def _probe_native_backends() -> BackendCapabilities: python_executable = _resolve_mlx_python() llama_server_path = _resolve_llama_server() @@ -2165,10 +2196,16 @@ def stream_generate( class RuntimeController: + # Hard upper bound on the warm pool independently of memory accounting — + # if psutil isn't available we still want a sane cap. MAX_WARM_MODELS = 2 + # Reserve this much physical memory for the OS / UI / unrelated + # processes when deciding whether a new (or incoming) model fits. Mirrors + # the headroom used by ``helpers/system.py::spareHeadroomGb``. 
+ WARM_POOL_MEMORY_HEADROOM_BYTES = 6 * 1024 * 1024 * 1024 - def __init__(self) -> None: - self.capabilities = get_backend_capabilities() + def __init__(self, *, background_probe: bool = False) -> None: + self.capabilities = _initial_backend_capabilities() self.engine: BaseInferenceEngine = MockInferenceEngine(self.capabilities) self.loaded_model: LoadedModelInfo | None = None self.runtime_note: str | None = None @@ -2178,6 +2215,51 @@ def __init__(self) -> None: self._loading_progress: dict[str, Any] | None = None self._loading_log_tail: list[str] = [] self._recent_orphaned_workers: list[dict[str, Any]] = [] + self._capability_probe_thread: Thread | None = None + self._capability_probe_lock = Lock() + if background_probe: + self.start_capability_probe() + + def start_capability_probe(self, *, force: bool = False) -> None: + with self._capability_probe_lock: + if ( + self._capability_probe_thread is not None + and self._capability_probe_thread.is_alive() + and not force + ): + return + thread = Thread( + target=self._capability_probe_worker, + kwargs={"force": force}, + name="chaosengine-capability-probe", + daemon=True, + ) + self._capability_probe_thread = thread + thread.start() + + def _capability_probe_worker(self, *, force: bool = False) -> None: + try: + capabilities = get_backend_capabilities(force=force) + except Exception as exc: + current = self.capabilities + capabilities = BackendCapabilities( + pythonExecutable=current.pythonExecutable, + mlxAvailable=False, + mlxLmAvailable=False, + mlxUsable=False, + mlxMessage=f"Native backend detection failed: {type(exc).__name__}: {exc}", + ggufAvailable=current.ggufAvailable, + llamaCliPath=current.llamaCliPath, + llamaServerPath=current.llamaServerPath, + llamaServerTurboPath=current.llamaServerTurboPath, + converterAvailable=False, + vllmAvailable=False, + vllmVersion=None, + probing=False, + ) + self.capabilities = capabilities + if isinstance(self.engine, MockInferenceEngine): + self.engine.capabilities = capabilities @staticmethod def _warm_pool_key( @@ -2266,6 +2348,7 @@ def _park_active_engine_or_unload( *, requested_identity: str, keep_warm_previous: bool = True, + required_free_bytes: int = 0, ) -> None: if not self.loaded_model or not self.engine: return @@ -2284,7 +2367,19 @@ def _park_active_engine_or_unload( except Exception: pass return - self._evict_warm_pool() + active_bytes = max( + self._model_resident_bytes(self.loaded_model), + self._engine_resident_bytes(self.engine), + ) + self._evict_warm_pool( + incoming_bytes=active_bytes, + ) + if not self._can_keep_warm_model(active_bytes, required_free_bytes=required_free_bytes): + try: + self.engine.unload_model() + except Exception: + pass + return self._warm_pool[current_key] = (self.engine, self.loaded_model) def _tracked_process_pids(self) -> set[int]: @@ -2460,6 +2555,8 @@ def refresh_capabilities(self, *, force: bool = False) -> BackendCapabilities: _LLAMA_HELP_CACHE.clear() _CACHE_TYPE_CACHE.clear() self.capabilities = get_backend_capabilities(force=force) + if isinstance(self.engine, MockInferenceEngine): + self.engine.capabilities = self.capabilities return self.capabilities def _select_engine( @@ -2540,15 +2637,100 @@ def warm_models(self) -> list[dict[str, Any]]: result.append({**info.to_dict(), "warm": True, "active": False}) return result - def _evict_warm_pool(self) -> None: - """Remove the oldest entry from the warm pool if at capacity.""" + @staticmethod + def _model_resident_bytes(info: LoadedModelInfo) -> int: + """Best-effort estimate of RAM held by a 
loaded model. + + For local weights we use on-disk size as a proxy — mlx-lm mmaps the + weights so RSS tracks file size closely; for llama.cpp / GGUF the + whole file ends up resident once warm. For catalog/no-path entries + we fall back to 0 (no useful estimate, treat as memory-free). + """ + return _path_size_bytes(info.path) if info.path else 0 + + @staticmethod + def _target_resident_bytes(*, path: str | None, runtime_target: str | None) -> int: + for candidate in (path, runtime_target): + if not candidate: + continue + size = _path_size_bytes(candidate) + if size > 0: + return size + return 0 + + @staticmethod + def _engine_resident_bytes(engine: BaseInferenceEngine | None) -> int: + if engine is None: + return 0 + pid_getter = getattr(engine, "process_pid", None) + pid = pid_getter() if callable(pid_getter) else None + if not isinstance(pid, int): + return 0 + try: + import psutil + + return int(psutil.Process(pid).memory_info().rss) + except Exception: + return 0 + + def _warm_pool_resident_bytes(self) -> int: + return sum( + max(self._model_resident_bytes(info), self._engine_resident_bytes(engine)) + for engine, info in self._warm_pool.values() + ) + + def _memory_budget_bytes(self) -> int: + """Bytes available for warm-pool weights, after OS headroom. + + Returns 0 when psutil isn't usable; callers must fall back to the + count-based MAX_WARM_MODELS cap in that case. + """ + try: + import psutil + + available = int(psutil.virtual_memory().available) + except Exception: + return 0 + return max(0, available - self.WARM_POOL_MEMORY_HEADROOM_BYTES) + + def _pop_oldest_warm_entry(self) -> None: + if not self._warm_pool: + return + oldest_key = next(iter(self._warm_pool)) + old_engine, _ = self._warm_pool.pop(oldest_key) + try: + old_engine.unload_model() + except Exception: + pass + + def _can_keep_warm_model(self, incoming_bytes: int, *, required_free_bytes: int = 0) -> bool: + budget = self._memory_budget_bytes() + if budget <= 0: + return True + if required_free_bytes > budget: + return False + return self._warm_pool_resident_bytes() + incoming_bytes <= budget + + def _evict_warm_pool(self, *, incoming_bytes: int = 0) -> None: + """Make room for an incoming entry in the warm pool. + + First applies the count cap (MAX_WARM_MODELS) so a flapping budget + can never grow the pool unboundedly. Then, if ``psutil`` reports a + live memory budget, evicts oldest entries until the pool plus the + incoming model fits within ``available - headroom``. + + ``incoming_bytes`` is the resident-byte estimate for the model + about to enter the pool (typically the model being parked from + active to warm). Passing 0 still triggers the count cap. 
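+
+        Illustrative numbers: with ``MAX_WARM_MODELS = 2``, a pool holding
+        A (12 GB, oldest) and B (6 GB), a 20 GB budget, and 8 GB of
+        ``incoming_bytes``, the count cap evicts A first (the pool is at
+        capacity); 6 + 8 = 14 GB then fits the budget, so B survives.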
+ """ while len(self._warm_pool) >= self.MAX_WARM_MODELS: - oldest_key = next(iter(self._warm_pool)) - old_engine, _ = self._warm_pool.pop(oldest_key) - try: - old_engine.unload_model() - except Exception: - pass + self._pop_oldest_warm_entry() + + budget = self._memory_budget_bytes() + if budget <= 0: + return + while self._warm_pool and self._warm_pool_resident_bytes() + incoming_bytes > budget: + self._pop_oldest_warm_entry() def load_model( self, @@ -2598,6 +2780,10 @@ def _internal_progress(progress: dict[str, Any]) -> None: runtime_target=runtime_target, path=path, ) + incoming_load_bytes = self._target_resident_bytes( + path=path, + runtime_target=runtime_target, + ) # Check warm pool first — instant switch if the exact runtime profile is cached pool_key = self._warm_pool_key( @@ -2639,6 +2825,7 @@ def _internal_progress(progress: dict[str, Any]) -> None: self._park_active_engine_or_unload( requested_identity=requested_identity, keep_warm_previous=keep_warm_previous, + required_free_bytes=incoming_load_bytes, ) self.engine = selected_engine diff --git a/backend_service/mlx_video_runtime.py b/backend_service/mlx_video_runtime.py index b52abfb..346d170 100644 --- a/backend_service/mlx_video_runtime.py +++ b/backend_service/mlx_video_runtime.py @@ -88,6 +88,11 @@ ("Lightricks/LTX-2.3", "ltx-2.3-spatial-upscaler-x2-1.0.safetensors"), ), } +_LTX2_SHARED_TEXT_ENCODER_CANDIDATES: tuple[str, ...] = ( + "prince-canuma/LTX-2-distilled", + "Lightricks/LTX-2", +) +_LTX2_TEXT_COMPONENTS: tuple[str, ...] = ("text_encoder", "tokenizer") _LTX2_DISTILLED_STAGE_1_STEPS = 8 _LTX2_DISTILLED_STAGE_2_STEPS = 3 @@ -255,6 +260,86 @@ def _resolve_ltx2_spatial_upscaler( ) +def _resolve_local_snapshot(repo_or_path: str) -> Path | None: + candidate = Path(repo_or_path) + if candidate.exists(): + return candidate + try: + from huggingface_hub import snapshot_download # type: ignore + + return Path(snapshot_download(repo_id=repo_or_path, local_files_only=True)) + except Exception: + return None + + +def _missing_ltx2_text_components(root: Path) -> list[str]: + missing: list[str] = [] + checks = { + "text_encoder": ( + root / "text_encoder" / "config.json", + root / "text_encoder" / "model.safetensors.index.json", + ), + "tokenizer": ( + root / "tokenizer" / "tokenizer.json", + root / "tokenizer" / "tokenizer.model", + ), + } + for component, required_paths in checks.items(): + if not (root / component).is_dir(): + missing.append(component) + continue + if not all(path.exists() for path in required_paths): + missing.append(component) + return missing + + +def _resolve_ltx2_text_component_source(repo: str) -> Path: + for candidate_repo in tuple(dict.fromkeys((repo, *_LTX2_SHARED_TEXT_ENCODER_CANDIDATES))): + snapshot = _resolve_local_snapshot(candidate_repo) + if snapshot is not None and not _missing_ltx2_text_components(snapshot): + return snapshot + checked = ", ".join(_LTX2_SHARED_TEXT_ENCODER_CANDIDATES) + raise RuntimeError( + "LTX-2.3 MLX generation needs shared text_encoder and tokenizer " + f"components, but none were found locally. Download {checked} or " + "resume this model download, then try again." + ) + + +def _prepare_ltx2_model_path(repo: str, workspace: Path) -> Path: + model_path = _resolve_local_snapshot(repo) + if model_path is None: + raise RuntimeError( + f"LTX-2 MLX model snapshot is not available locally for {repo}. " + "Download the model before generating." 
+ ) + + missing = _missing_ltx2_text_components(model_path) + if not missing: + return model_path + + text_source = _resolve_ltx2_text_component_source(repo) + overlay = workspace / "model-overlay" + shutil.rmtree(overlay, ignore_errors=True) + overlay.mkdir(parents=True, exist_ok=True) + + missing_set = set(missing) + for entry in model_path.iterdir(): + if entry.name in missing_set: + continue + (overlay / entry.name).symlink_to(entry, target_is_directory=entry.is_dir()) + for component in _LTX2_TEXT_COMPONENTS: + target = overlay / component + if target.exists() or target.is_symlink(): + if target.is_dir() and not target.is_symlink(): + shutil.rmtree(target) + else: + target.unlink() + source = text_source / component + target.symlink_to(source, target_is_directory=True) + return overlay + + class _ProgressSink(Protocol): def __call__(self, phase: str, message: str, fraction: float) -> None: ... @@ -409,12 +494,15 @@ def _build_cmd( entry = _resolve_entry_point(config.repo) python = _resolve_video_python() pipeline_flag = _resolve_pipeline_flag(config.repo) + model_repo_arg = config.repo + if resolve_aux_files and "ltx-2.3" in config.repo.lower(): + model_repo_arg = str(_prepare_ltx2_model_path(config.repo, output_path.parent)) cmd = [ python, "-m", entry, "--model-repo", - config.repo, + model_repo_arg, "--pipeline", pipeline_flag, "--prompt", diff --git a/backend_service/models/__init__.py b/backend_service/models/__init__.py index fe70463..a47fe80 100644 --- a/backend_service/models/__init__.py +++ b/backend_service/models/__init__.py @@ -195,6 +195,7 @@ class DeleteModelRequest(BaseModel): class DownloadModelRequest(BaseModel): repo: str = Field(min_length=3, max_length=256) + modelId: str | None = Field(default=None, min_length=1, max_length=256) class ImageGenerationRequest(BaseModel): diff --git a/backend_service/plugins/__init__.py b/backend_service/plugins/__init__.py index f218910..7d8f56d 100644 --- a/backend_service/plugins/__init__.py +++ b/backend_service/plugins/__init__.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from enum import Enum from pathlib import Path +from threading import RLock from typing import Any import importlib import json @@ -38,28 +39,43 @@ class BasePlugin(ABC): def manifest(self) -> PluginManifest: ... 
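+# Builtins are registered lazily on first registry access (``ensure_builtins``),
+# so importing this module no longer eagerly imports every strategy and tool.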
class PluginRegistry: - def __init__(self): + def __init__(self, *, auto_register_builtins: bool = False): self._plugins: dict[str, tuple[PluginManifest, Any]] = {} + self._auto_register_builtins = auto_register_builtins + self._builtins_registered = False + self._lock = RLock() def register(self, manifest: PluginManifest, instance: Any = None): self._plugins[manifest.id] = (manifest, instance) + def ensure_builtins(self) -> None: + if not self._auto_register_builtins or self._builtins_registered: + return + with self._lock: + if not self._builtins_registered: + self.register_builtins() + def get(self, plugin_id: str) -> tuple[PluginManifest, Any] | None: + self.ensure_builtins() return self._plugins.get(plugin_id) def list_all(self) -> list[PluginManifest]: + self.ensure_builtins() return [m for m, _ in self._plugins.values()] def list_by_type(self, plugin_type: PluginType) -> list[tuple[PluginManifest, Any]]: + self.ensure_builtins() return [(m, i) for m, i in self._plugins.values() if m.plugin_type == plugin_type] def enable(self, plugin_id: str) -> bool: + self.ensure_builtins() if plugin_id in self._plugins: self._plugins[plugin_id][0].enabled = True return True return False def disable(self, plugin_id: str) -> bool: + self.ensure_builtins() if plugin_id in self._plugins: self._plugins[plugin_id][0].enabled = False return True @@ -89,7 +105,7 @@ def register_builtins(self): """Register all built-in components as plugins.""" # Cache strategies from cache_compression import registry as cache_registry - for strategy in cache_registry._strategies.values(): + for strategy in cache_registry.strategies(): manifest = PluginManifest( id=f"cache.{strategy.strategy_id}", name=strategy.name, @@ -110,7 +126,7 @@ def register_builtins(self): description=tool.description, ) self.register(manifest, tool) + self._builtins_registered = True # Module singleton -plugin_registry = PluginRegistry() -plugin_registry.register_builtins() +plugin_registry = PluginRegistry(auto_register_builtins=True) diff --git a/backend_service/routes/cache.py b/backend_service/routes/cache.py index 418b20e..1edd5ed 100644 --- a/backend_service/routes/cache.py +++ b/backend_service/routes/cache.py @@ -2,7 +2,7 @@ from typing import Any -from fastapi import APIRouter, Query +from fastapi import APIRouter, Query, Request from backend_service.app import _build_system_snapshot, compute_cache_preview @@ -11,6 +11,7 @@ @router.get("/api/cache/preview") def cache_preview( + request: Request, bits: int = Query(3, ge=0, le=8), fp16_layers: int = Query(4, ge=0, le=16), num_layers: int = Query(32, ge=1, le=160), @@ -21,7 +22,7 @@ def cache_preview( params_b: float = Query(7.0, ge=0.5, le=1000.0), strategy: str = Query("native"), ) -> dict[str, Any]: - system_stats = _build_system_snapshot() + system_stats = _build_system_snapshot(capabilities=request.app.state.chaosengine.runtime.capabilities) return compute_cache_preview( bits=bits, fp16_layers=fp16_layers, diff --git a/backend_service/routes/health.py b/backend_service/routes/health.py index 8ddf97e..8d40109 100644 --- a/backend_service/routes/health.py +++ b/backend_service/routes/health.py @@ -15,18 +15,16 @@ def health(request: Request) -> dict[str, Any]: state = request.app.state.chaosengine from backend_service.app import WORKSPACE_ROOT, app_version - runtime_status = state.runtime.status( - active_requests=state.active_requests, - requests_served=state.requests_served, - ) + capabilities = state.runtime.capabilities + loaded_model = state.runtime.loaded_model return { "status": 
"ok", "workspaceRoot": str(WORKSPACE_ROOT), - "runtime": _runtime_label(), + "runtime": _runtime_label(capabilities.to_dict()), "appVersion": app_version, - "engine": runtime_status["engine"], - "loadedModel": runtime_status["loadedModel"], - "nativeBackends": runtime_status["nativeBackends"], + "engine": state.runtime.engine.engine_name, + "loadedModel": loaded_model.to_dict() if loaded_model is not None else None, + "nativeBackends": capabilities.to_dict(), } diff --git a/backend_service/routes/images.py b/backend_service/routes/images.py index 9a741c5..7f81689 100644 --- a/backend_service/routes/images.py +++ b/backend_service/routes/images.py @@ -25,11 +25,53 @@ _find_image_output, _delete_image_output, ) -from backend_service.progress import GenerationCancelled, IMAGE_PROGRESS +from backend_service.progress import GenerationCancelled, IMAGE_PROGRESS, VIDEO_PROGRESS router = APIRouter() +def _unload_idle_video_runtime_for_image(request: Request, action: str) -> None: + """Free resident video diffusion weights before image work starts. + + Image and video pipelines live in separate managers, so loading an image + model no longer implicitly releases a previously-loaded video model. That + can leave tens of GB resident across Studio switches. If video generation + is actively running, fail fast instead of blocking the image request behind + a long render. + """ + state = request.app.state.chaosengine + if VIDEO_PROGRESS.snapshot().get("active"): + raise HTTPException( + status_code=409, + detail=( + "A video generation is still running. Wait for it to finish or cancel it " + "before loading an image model." + ), + ) + try: + runtime = state.video_runtime.capabilities() + except Exception: + return + loaded_repo = str(runtime.get("loadedModelRepo") or "") + if not loaded_repo: + return + try: + state.video_runtime.unload() + except Exception as exc: + state.add_log( + "images", + "warning", + f"Could not unload video model before {action}: {type(exc).__name__}: {exc}", + ) + return + state.add_log( + "images", + "info", + f"Unloaded video model {loaded_repo} before {action} to free memory.", + ) + state.add_activity("Video model unloaded", f"Freed memory for {action}") + + @router.get("/api/images/catalog") def image_catalog(request: Request) -> dict[str, Any]: state = request.app.state.chaosengine @@ -89,6 +131,7 @@ def preload_image_model(request: Request, body: ImageRuntimePreloadRequest) -> d validation_error = _image_download_validation_error(variant["repo"]) detail = validation_error or f"{variant['name']} is not installed locally yet." raise HTTPException(status_code=409, detail=detail) + _unload_idle_video_runtime_for_image(request, "image preload") try: runtime = state.image_runtime.preload(variant["repo"]) except RuntimeError as exc: @@ -185,6 +228,7 @@ def generate_image(request: Request, body: ImageGenerationRequest) -> dict[str, state.add_log("images", "error", f"Image model not found in catalog or tracked seeds: '{body.modelId}'") raise HTTPException(status_code=404, detail=f"Unknown image model '{body.modelId}'. 
The model isn't in the curated catalog or tracked seeds.") state.add_log("images", "info", f"Resolved variant: {variant.get('name')} (repo={variant.get('repo')})") + _unload_idle_video_runtime_for_image(request, "image generation") try: artifacts, runtime = _generate_image_artifacts(body, variant, state.image_runtime) except GenerationCancelled: diff --git a/backend_service/routes/setup.py b/backend_service/routes/setup.py index 413e333..dcdfd92 100644 --- a/backend_service/routes/setup.py +++ b/backend_service/routes/setup.py @@ -1,5 +1,6 @@ from __future__ import annotations +import importlib import os import platform import shutil @@ -210,6 +211,8 @@ def install_pip_package(request: Request, body: InstallPackageRequest) -> dict[s f"{', '.join(unique)}" ).strip() + importlib.invalidate_caches() + # Re-probe capabilities after install state.runtime.refresh_capabilities(force=True) caps = state.runtime.capabilities.to_dict() diff --git a/backend_service/routes/video.py b/backend_service/routes/video.py index 54485ec..c11a977 100644 --- a/backend_service/routes/video.py +++ b/backend_service/routes/video.py @@ -13,13 +13,17 @@ from fastapi.responses import FileResponse from backend_service.helpers.video import ( + _VIDEO_MLX_TEXT_ENCODER_ALLOW_PATTERNS, _find_video_variant, _find_video_variant_by_repo, + _is_video_download_repo, _is_video_repo, _video_download_repo_ids, _video_download_validation_error, _video_model_payloads, + _video_variant_missing_text_encoder_repo, _video_variant_available_locally, + _video_variant_validation_error, ) from backend_service.models import ( DownloadModelRequest, @@ -27,12 +31,47 @@ VideoRuntimePreloadRequest, VideoRuntimeUnloadRequest, ) -from backend_service.progress import GenerationCancelled, VIDEO_PROGRESS +from backend_service.progress import GenerationCancelled, IMAGE_PROGRESS, VIDEO_PROGRESS router = APIRouter() +def _unload_idle_image_runtime_for_video(request: Request, action: str) -> None: + """Free resident image diffusion weights before video work starts.""" + state = request.app.state.chaosengine + if IMAGE_PROGRESS.snapshot().get("active"): + raise HTTPException( + status_code=409, + detail=( + "An image generation is still running. Wait for it to finish or cancel it " + "before loading a video model." + ), + ) + try: + runtime = state.image_runtime.capabilities() + except Exception: + return + loaded_repo = str(runtime.get("loadedModelRepo") or "") + if not loaded_repo: + return + try: + state.image_runtime.unload() + except Exception as exc: + state.add_log( + "video", + "warning", + f"Could not unload image model before {action}: {type(exc).__name__}: {exc}", + ) + return + state.add_log( + "video", + "info", + f"Unloaded image model {loaded_repo} before {action} to free memory.", + ) + state.add_activity("Image model unloaded", f"Freed memory for {action}") + + @router.get("/api/video/catalog") def video_catalog(request: Request) -> dict[str, Any]: """Return the curated catalog of video generation models.""" @@ -114,10 +153,11 @@ def preload_video_model(request: Request, body: VideoRuntimePreloadRequest) -> d raise HTTPException(status_code=404, detail=f"Unknown video model '{body.modelId}'.") if not _video_variant_available_locally(variant): - validation_error = _video_download_validation_error(variant["repo"]) + validation_error = _video_variant_validation_error(variant) detail = validation_error or f"{variant['name']} is not installed locally yet." 
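+        # 409 rather than 404: the model is known to the catalog, it just
+        # is not fully present on disk yet.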
raise HTTPException(status_code=409, detail=detail) + _unload_idle_image_runtime_for_video(request, "video preload") try: runtime = state.video_runtime.preload(variant["repo"]) except RuntimeError as exc: @@ -257,10 +297,11 @@ def generate_video(request: Request, body: VideoGenerationRequest) -> dict[str, ) if not _video_variant_available_locally(variant): - validation_error = _video_download_validation_error(variant["repo"]) + validation_error = _video_variant_validation_error(variant) detail = validation_error or f"{variant['name']} is not installed locally yet." raise HTTPException(status_code=409, detail=detail) + _unload_idle_image_runtime_for_video(request, "video generation") try: artifact, runtime = _generate_video_artifact(body, variant, state.video_runtime) except GenerationCancelled: @@ -305,12 +346,47 @@ def download_video_model(request: Request, body: DownloadModelRequest) -> dict[s at an arbitrary model via the API. """ state = request.app.state.chaosengine + variant = _find_video_variant(body.modelId) if body.modelId else None + if body.modelId and variant is None: + raise HTTPException(status_code=404, detail=f"Unknown video model '{body.modelId}'.") + if variant is not None and variant.get("ggufFile"): + base_error = _video_download_validation_error(str(variant["repo"])) + if base_error: + label = variant["name"] + state.add_log("video", "info", f"Video download requested: {label} base ({variant['repo']})") + return {"download": state.start_download(str(variant["repo"]))} + gguf_repo = str(variant.get("ggufRepo") or "") + gguf_file = str(variant.get("ggufFile") or "") + if not gguf_repo or not gguf_file: + raise HTTPException(status_code=400, detail=f"GGUF metadata is incomplete for {variant['name']}.") + state.add_log("video", "info", f"Video download requested: {variant['name']} GGUF ({gguf_repo}/{gguf_file})") + return { + "download": state.start_download( + gguf_repo, + allow_patterns=[gguf_file, "*.md", "LICENSE*"], + ) + } + + if variant is not None: + text_encoder_repo = _video_variant_missing_text_encoder_repo(variant) + if text_encoder_repo: + state.add_log( + "video", + "info", + f"Video download requested: {variant['name']} shared text encoder ({text_encoder_repo})", + ) + return { + "download": state.start_download( + text_encoder_repo, + allow_patterns=list(_VIDEO_MLX_TEXT_ENCODER_ALLOW_PATTERNS), + ) + } + if not _is_video_repo(body.repo): raise HTTPException( status_code=404, detail=f"Repo '{body.repo}' is not in the curated video model catalog.", ) - variant = _find_video_variant_by_repo(body.repo) label = variant["name"] if variant else body.repo state.add_log("video", "info", f"Video download requested: {label} ({body.repo})") return {"download": state.start_download(body.repo)} @@ -332,7 +408,7 @@ def video_download_status(request: Request) -> dict[str, Any]: @router.post("/api/video/download/cancel") def cancel_video_download(request: Request, body: DownloadModelRequest) -> dict[str, Any]: state = request.app.state.chaosengine - if not _is_video_repo(body.repo): + if not _is_video_download_repo(body.repo): raise HTTPException( status_code=404, detail=f"Repo '{body.repo}' is not in the curated video model catalog.", @@ -343,7 +419,7 @@ def cancel_video_download(request: Request, body: DownloadModelRequest) -> dict[ @router.post("/api/video/download/delete") def delete_video_download(request: Request, body: DownloadModelRequest) -> dict[str, Any]: state = request.app.state.chaosengine - if not _is_video_repo(body.repo): + if not 
_is_video_download_repo(body.repo): raise HTTPException( status_code=404, detail=f"Repo '{body.repo}' is not in the curated video model catalog.", diff --git a/backend_service/state.py b/backend_service/state.py index 63582f9..67fcfa9 100644 --- a/backend_service/state.py +++ b/backend_service/state.py @@ -18,7 +18,6 @@ from fastapi import HTTPException from starlette.responses import StreamingResponse -from cache_compression import registry as cache_registry from backend_service.catalog import CATALOG from backend_service.inference import RuntimeController @@ -170,6 +169,7 @@ def __init__( benchmarks_path: Path | None = None, chat_sessions_path: Path | None = None, library_cache_path: Path | None = None, + background_capability_probe: bool = False, ) -> None: # Defer imports of module-level constants to avoid circular imports from backend_service.app import ( @@ -210,7 +210,7 @@ def __init__( self._library_scan_done.set() else: self._library_scan_done.set() - self.runtime = RuntimeController() + self.runtime = RuntimeController(background_probe=background_capability_probe) self._image_runtime: "ImageRuntimeManager | None" = None self._video_runtime: "VideoRuntimeManager | None" = None self._chat_sessions_path = chat_sessions_path if chat_sessions_path is not None else CHAT_SESSIONS_PATH @@ -419,10 +419,16 @@ def _settings_payload(self, library: list[dict[str, Any]]) -> dict[str, Any]: "hfCachePath": str(self.settings.get("hfCachePath") or ""), } + def _system_snapshot(self) -> dict[str, Any]: + try: + return self._system_snapshot_provider(capabilities=self.runtime.capabilities) + except TypeError: + return self._system_snapshot_provider() + def _bootstrap(self) -> None: from backend_service.app import app_version - system = self._system_snapshot_provider() + system = self._system_snapshot() recommendation = _best_fit_recommendation(system) self.add_log("chaosengine", "info", f"Workspace booted in {system['backendLabel']} mode.") self.add_log("chaosengine", "info", f"ChaosEngine v{app_version} detected.") @@ -494,6 +500,8 @@ def _native_cache_label() -> str: return "Native f16 cache" def _cache_label(self, *, cache_strategy: str, bits: int, fp16_layers: int) -> str: + from cache_compression import registry as cache_registry + strategy = cache_registry.get(cache_strategy) if strategy is not None: return strategy.label(bits, fp16_layers) @@ -853,6 +861,40 @@ def _default_session_model(self) -> dict[str, Any]: "treeBudget": model_info.treeBudget, } + # No model is currently loaded. Prefer a model the user actually has + # downloaded over a catalog default — surfacing a catalog-only entry + # (e.g. nvidia/NVIDIA-Nemotron-3-Nano-4B-GGUF) just produces a + # confusing "Failed to load … isn't downloaded on this machine" + # error when the user clicks Load. 
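+        # The first healthy text-capable entry wins; image/video entries and
+        # entries flagged ``broken`` are skipped.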
+ for entry in self._library(): + entry_type = entry.get("modelType") + if entry_type and entry_type != "text": + continue + if entry.get("broken"): + continue + return { + "model": entry["name"], + "modelRef": entry["name"], + "canonicalRepo": entry.get("canonicalRepo") or entry.get("repo"), + "modelSource": "library", + "modelPath": entry["path"], + "modelBackend": entry.get("backend", "auto"), + "cacheLabel": self._cache_label( + cache_strategy=str(launch_preferences["cacheStrategy"]), + bits=int(launch_preferences["cacheBits"]), + fp16_layers=int(launch_preferences["fp16Layers"]), + ), + "cacheStrategy": launch_preferences["cacheStrategy"], + "cacheBits": launch_preferences["cacheBits"], + "fp16Layers": launch_preferences["fp16Layers"], + "fusedAttention": launch_preferences["fusedAttention"], + "fitModelInMemory": launch_preferences["fitModelInMemory"], + "contextTokens": launch_preferences["contextTokens"], + "speculativeDecoding": launch_preferences.get("speculativeDecoding", False), + "dflashDraftModel": None, + "treeBudget": launch_preferences.get("treeBudget", 0), + } + default_variant = _default_chat_variant() return { "model": default_variant["name"], @@ -1087,6 +1129,7 @@ def update_settings(self, request: UpdateSettingsRequest) -> dict[str, Any]: next_settings["remoteProviders"] = normalized if request.huggingFaceToken is not None: + previous_token_value = str(next_settings.get("huggingFaceToken") or "") token_value = request.huggingFaceToken.strip() next_settings["huggingFaceToken"] = token_value if token_value: @@ -1095,6 +1138,12 @@ def update_settings(self, request: UpdateSettingsRequest) -> dict[str, Any]: else: os.environ.pop("HF_TOKEN", None) os.environ.pop("HUGGING_FACE_HUB_TOKEN", None) + if token_value != previous_token_value: + from backend_service.helpers.huggingface import _clear_huggingface_caches + from backend_service.helpers.images import _clear_image_discover_caches + + _clear_huggingface_caches() + _clear_image_discover_caches() # Output directory overrides. Empty string clears the override. # Anything non-empty must be absolute or ~-relative — same rule as @@ -1199,7 +1248,7 @@ def _conversion_details( fp16_layers=launch_preferences["fp16Layers"], context_tokens=launch_preferences["contextTokens"], params_b=params_b, - system_stats=self._system_snapshot_provider(), + system_stats=self._system_snapshot(), ) if params_b is not None else None @@ -1307,7 +1356,7 @@ def run_benchmark(self, request: BenchmarkRunRequest) -> dict[str, Any]: fp16_layers=request.fp16Layers, context_tokens=request.contextTokens, params_b=params_b, - system_stats=self._system_snapshot_provider(), + system_stats=self._system_snapshot(), ) use_compressed = request.cacheBits > 0 cache_gb = preview["optimizedCacheGb"] if use_compressed else preview["baselineCacheGb"] @@ -2428,7 +2477,7 @@ def _sse_stream(): headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) - def start_download(self, repo: str) -> dict[str, Any]: + def start_download(self, repo: str, allow_patterns: list[str] | None = None) -> dict[str, Any]: from backend_service.helpers.huggingface import ( _friendly_hf_download_error, _hf_repo_downloaded_bytes, @@ -2535,7 +2584,7 @@ def _progress_worker() -> None: # allowlist so we skip legacy single-file checkpoints the # pipelines never load. Both helpers return None for repos # outside their catalog, so only one ever applies. 
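+        # An explicit ``allow_patterns`` argument (used by the GGUF and
+        # shared text-encoder download paths in routes/video.py) takes
+        # precedence over the catalog-derived allowlists below.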
- allow_patterns = ( + effective_allow_patterns = allow_patterns or ( _video_repo_allow_patterns(repo) or _image_repo_allow_patterns(repo) ) @@ -2543,7 +2592,7 @@ def _progress_worker() -> None: repo, env, process_log, - allow_patterns=allow_patterns, + allow_patterns=effective_allow_patterns, ) with self._lock: if self._download_tokens.get(repo) == download_token: @@ -2869,7 +2918,7 @@ def server_status(self) -> dict[str, Any]: def workspace(self) -> dict[str, Any]: from backend_service.app import compute_cache_preview - system_stats = self._system_snapshot_provider() + system_stats = self._system_snapshot() try: loaded_name = self.runtime.loaded_model.name if self.runtime.loaded_model else None loaded_engine = self.runtime.engine.engine_name if self.runtime.engine else None diff --git a/backend_service/video_runtime.py b/backend_service/video_runtime.py index e7426ce..f301294 100644 --- a/backend_service/video_runtime.py +++ b/backend_service/video_runtime.py @@ -32,7 +32,6 @@ from backend_service.helpers.gpu import nvidia_gpu_present from backend_service.image_runtime import validate_local_diffusers_snapshot -from cache_compression import apply_diffusion_cache_strategy from backend_service.progress import ( GenerationCancelled, PHASE_DECODING, @@ -540,15 +539,15 @@ def _enhance_prompt(repo: str, prompt: str) -> tuple[str, str | None]: # pressure. Numbers come from the catalog ``sizeGb`` estimates for the # stock variants; GGUF Q4/Q6/Q8 variants override at the call site. _VIDEO_MODEL_FOOTPRINT_BF16_GB: dict[str, float] = { - "Lightricks/LTX-Video": 14.0, + "Lightricks/LTX-Video": 10.0, "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": 9.0, "Wan-AI/Wan2.1-T2V-14B-Diffusers": 28.0, "Wan-AI/Wan2.2-TI2V-5B-Diffusers": 11.0, "Wan-AI/Wan2.2-T2V-A14B-Diffusers": 28.0, "hunyuanvideo-community/HunyuanVideo": 26.0, - "genmo/mochi-1-preview": 20.0, - "THUDM/CogVideoX-2b": 10.0, - "THUDM/CogVideoX-5b": 18.0, + "genmo/mochi-1-preview": 22.0, + "THUDM/CogVideoX-2b": 19.0, + "THUDM/CogVideoX-5b": 33.0, } # GGUF quant level → multiplier vs the bf16 footprint. Keys are matched as @@ -985,6 +984,8 @@ def generate(self, config: VideoGenerationConfig) -> GeneratedVideo: # ~1.3–2× on Wan). NotImplementedError is swallowed by the # helper when the pipeline class has no vendored patch yet; # see FU-007 in CLAUDE.md. + from cache_compression import apply_diffusion_cache_strategy + apply_diffusion_cache_strategy( pipeline, strategy_id=config.cacheStrategy, @@ -1527,6 +1528,11 @@ def _ensure_pipeline( pipeline_kwargs["transformer"] = quantized_transformer if gguf_note: VIDEO_PROGRESS.set_phase(PHASE_LOADING, message=gguf_note) + if quantized_transformer is None: + raise RuntimeError( + gguf_note + or f"Could not load requested GGUF transformer {gguf_file}." + ) elif use_nf4: VIDEO_PROGRESS.set_phase( PHASE_LOADING, @@ -1635,9 +1641,10 @@ def _try_load_gguf_transformer( Mirrors the image-side loader: GGUF weights cover the DiT only; VAE and text encoders are loaded from the base ``repo`` snapshot. - All failure modes are non-fatal — a missing ``gguf`` package, an - old diffusers without ``GGUFQuantizationConfig``, or an HF cache - miss falls back to the standard fp16 / bf16 transformer path. + The helper itself only reports ``(None, note)`` on failure so tests + can exercise each missing-dependency path. ``_ensure_pipeline`` treats + a requested GGUF variant as strict and raises with that note rather + than silently loading the full fp16 / bf16 transformer. 
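+
+    Example: a missing ``gguf`` package makes this return ``(None, note)``;
+    ``_ensure_pipeline`` then raises ``RuntimeError(note)`` instead of
+    quietly building the pipeline with the bf16 transformer.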
""" if importlib.util.find_spec("gguf") is None: return None, ( diff --git a/cache_compression/__init__.py b/cache_compression/__init__.py index ad3cb81..1bcfa2c 100644 --- a/cache_compression/__init__.py +++ b/cache_compression/__init__.py @@ -10,6 +10,7 @@ from abc import ABC, abstractmethod import importlib import platform +from threading import RLock from typing import Any @@ -147,18 +148,27 @@ class CacheStrategyRegistry: def __init__(self) -> None: self._strategies: dict[str, CacheStrategy] = {} + self._discovered = False + self._lock = RLock() def register(self, strategy: CacheStrategy) -> None: self._strategies[strategy.strategy_id] = strategy def get(self, strategy_id: str) -> CacheStrategy | None: + self._ensure_discovered() return self._strategies.get(strategy_id) def default(self) -> CacheStrategy: + self._ensure_discovered() return self._strategies["native"] + def strategies(self) -> list[CacheStrategy]: + self._ensure_discovered() + return list(self._strategies.values()) + def available(self) -> list[dict[str, Any]]: """Return a JSON-friendly list for the frontend.""" + self._ensure_discovered() out: list[dict[str, Any]] = [] for s in self._strategies.values(): out.append({ @@ -176,11 +186,19 @@ def available(self) -> list[dict[str, Any]]: }) return out + def _ensure_discovered(self) -> None: + if self._discovered: + return + with self._lock: + if not self._discovered: + self.discover() + def discover(self) -> list[CacheStrategy]: """Import all known adapter modules and return available strategies.""" - self._strategies = {} + with self._lock: + self._strategies = {} - strategy_specs = [ + strategy_specs = [ { "id": "native", "name": "Native f16", @@ -248,31 +266,32 @@ def discover(self) -> list[CacheStrategy]: "supports_fp16_layers": False, "required_llama_binary": "standard", }, - ] - - for spec in strategy_specs: - try: - module = importlib.import_module(spec["module"]) - cls = getattr(module, spec["class_name"]) - instance = cls() - except Exception as exc: - if spec["id"] == "native": - raise - instance = _BrokenStrategy( - strategy_id=str(spec["id"]), - name=str(spec["name"]), - bit_range=spec["bit_range"], - default_bits=spec["default_bits"], - supports_fp16_layers=bool(spec["supports_fp16_layers"]), - required_llama_binary=str(spec.get("required_llama_binary", "standard")), - reason=( - f"{spec['name']} could not be loaded in this runtime. " - f"ChaosEngineAI kept the card visible so the UI does not silently collapse to Native f16 only. " - f"Import error: {exc}" - ), - ) - self.register(instance) - return list(self._strategies.values()) + ] + + for spec in strategy_specs: + try: + module = importlib.import_module(spec["module"]) + cls = getattr(module, spec["class_name"]) + instance = cls() + except Exception as exc: + if spec["id"] == "native": + raise + instance = _BrokenStrategy( + strategy_id=str(spec["id"]), + name=str(spec["name"]), + bit_range=spec["bit_range"], + default_bits=spec["default_bits"], + supports_fp16_layers=bool(spec["supports_fp16_layers"]), + required_llama_binary=str(spec.get("required_llama_binary", "standard")), + reason=( + f"{spec['name']} could not be loaded in this runtime. " + f"ChaosEngineAI kept the card visible so the UI does not silently collapse to Native f16 only. 
" + f"Import error: {exc}" + ), + ) + self.register(instance) + self._discovered = True + return list(self._strategies.values()) class _BrokenStrategy(CacheStrategy): @@ -330,7 +349,6 @@ def required_llama_binary(self) -> str: # Module-level singleton — import and use ``registry`` directly. registry = CacheStrategyRegistry() -registry.discover() def apply_diffusion_cache_strategy( diff --git a/cache_compression/chaosengine.py b/cache_compression/chaosengine.py index e088dfe..c8d9000 100644 --- a/cache_compression/chaosengine.py +++ b/cache_compression/chaosengine.py @@ -16,28 +16,16 @@ from __future__ import annotations -import importlib +import importlib.util from typing import Any from cache_compression import CacheStrategy -def _load_chaosengine() -> Any | None: - try: - return importlib.import_module("chaos_engine") - except ImportError: - return None - - def _chaosengine_available() -> bool: - mod = _load_chaosengine() - if mod is None: - return False - # Check for the core cache module try: - cache_mod = importlib.import_module("chaos_engine.cache") - return hasattr(cache_mod, "config") or True - except ImportError: + return importlib.util.find_spec("chaos_engine") is not None + except (ImportError, AttributeError, ValueError): return False diff --git a/cache_compression/triattention.py b/cache_compression/triattention.py index fe2d339..ef524ec 100644 --- a/cache_compression/triattention.py +++ b/cache_compression/triattention.py @@ -17,36 +17,23 @@ from __future__ import annotations +import importlib.util from typing import Any from cache_compression import CacheStrategy -_triattention = None -_vllm = None -_mlx_lm = None -try: - import triattention as _triattention # type: ignore[import-untyped] -except ImportError: - pass -try: - import vllm as _vllm # type: ignore[import-untyped] -except ImportError: - pass -try: - import mlx_lm as _mlx_lm # type: ignore[import-untyped] -except ImportError: - pass +def _module_available(module_name: str) -> bool: + try: + return importlib.util.find_spec(module_name) is not None + except (ImportError, AttributeError, ValueError): + return False def _has_mlx_entrypoint() -> bool: - if _triattention is None or _mlx_lm is None: - return False - try: - from triattention.mlx import apply_triattention_mlx # noqa: F401 - return True - except ImportError: - return False + # Keep availability checks side-effect free. Importing mlx_lm can touch + # MLX/Metal at module load and can abort in headless or sandboxed contexts. + return _module_available("triattention") and _module_available("mlx_lm") class TriAttentionStrategy(CacheStrategy): @@ -63,7 +50,7 @@ def has_mlx_backend(self) -> bool: return _has_mlx_entrypoint() def has_vllm_backend(self) -> bool: - return _triattention is not None and _vllm is not None + return _module_available("triattention") and _module_available("vllm") def is_available(self) -> bool: return self.has_mlx_backend() or self.has_vllm_backend() @@ -119,7 +106,7 @@ def apply_vllm_patches(self) -> None: Must be called BEFORE creating a ``vllm.LLM`` instance. 
""" - if _triattention is None: + if not _module_available("triattention"): raise RuntimeError("triattention is not installed.") try: from triattention.vllm.runtime.integration_monkeypatch import ( diff --git a/cache_compression/turboquant.py b/cache_compression/turboquant.py index 32797f4..52e8044 100644 --- a/cache_compression/turboquant.py +++ b/cache_compression/turboquant.py @@ -51,6 +51,22 @@ def _has_required_turboquant_mlx_hooks() -> bool: return all(any(hook in source for source in sources) for hook in _REQUIRED_HOOKS) +def _has_full_turboquant_mlx_package() -> bool: + if not _has_required_turboquant_mlx_hooks(): + return False + try: + module = importlib.import_module("turboquant_mlx") + except ImportError: + return False + checker = getattr(module, "_has_full_turboquant", None) + if not callable(checker): + return False + try: + return bool(checker()) + except Exception: + return False + + def _load_turboquant_mlx_hooks() -> tuple[Any | None, Any | None]: if not _has_required_turboquant_mlx_hooks(): return None, None @@ -72,10 +88,10 @@ def name(self) -> str: return "TurboQuant" def is_available(self) -> bool: - # Keep availability probing side-effect free. Some MLX packages touch - # Metal during import, so we only report ready when the expected hooks - # are present in the installed source tree. - return _has_required_turboquant_mlx_hooks() + # The in-tree adapter provides ChaosEngineAI's stable hooks, while + # turboquant-mlx-full provides the actual TurboQuantKVCache. Report + # "Ready" only when both are visible to the backend runtime. + return _has_full_turboquant_mlx_package() def availability_badge(self) -> str: return "Ready" if self.is_available() else "Experimental" @@ -86,6 +102,11 @@ def availability_tone(self) -> str: def availability_reason(self) -> str | None: if self.is_available(): return None + if not _has_required_turboquant_mlx_hooks(): + return ( + "ChaosEngineAI's TurboQuant MLX adapter is not available in " + "this runtime. Rebuild or update the app, then retry." + ) return ( "Install turboquant-mlx (arozanov fork; PyPI name " "``turboquant-mlx-full``) into ChaosEngineAI's backend runtime: " diff --git a/pyproject.toml b/pyproject.toml index 479aec4..6e93ee3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,4 +69,4 @@ testpaths = ["tests"] addopts = "--tb=short -q" [tool.setuptools.packages.find] -include = ["backend_service*", "cache_compression*", "dflash*"] +include = ["backend_service*", "cache_compression*", "dflash*", "turboquant_mlx*"] diff --git a/scripts/stage-runtime.mjs b/scripts/stage-runtime.mjs index 714e63d..bded9c9 100644 --- a/scripts/stage-runtime.mjs +++ b/scripts/stage-runtime.mjs @@ -61,7 +61,7 @@ function main() { pruneBundledProjectArtifacts(); ensureDir(backendDest); - for (const relativePath of ["backend_service", "cache_compression"]) { + for (const relativePath of ["backend_service", "cache_compression", "turboquant_mlx"]) { copyTree(path.join(workspaceRoot, relativePath), path.join(backendDest, relativePath)); } for (const relativeFile of ["README.md", "pyproject.toml"]) { diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index a46edc2..4f29137 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -636,12 +636,15 @@ fn apply_embedded_runtime_env(command: &mut Command, runtime: &EmbeddedRuntime) .env("PYTHONNOUSERSITE", "1") .env("CHAOSENGINE_EMBEDDED_RUNTIME", "1"); - // Prepend the user-local extras dir to PYTHONPATH so packages installed - // at runtime (CUDA torch, diffusers, etc. 
via /api/setup/install-gpu-bundle)
-    // shadow anything in the bundled site-packages. The extras dir lives
-    // outside the ephemeral %TEMP% runtime extraction so it survives app
-    // updates — the installer re-extracts the bundled runtime from scratch
-    // on each launch, but never touches the extras tree.
+    // Insert the user-local extras dir after the app backend but before
+    // bundled site-packages. Runtime-installed packages (CUDA torch,
+    // diffusers, etc.) still shadow bundled third-party wheels, while
+    // app-owned adapter modules in backend/ keep priority over same-named
+    // upstream packages installed into extras.
+    //
+    // The extras dir lives outside the ephemeral %TEMP% runtime extraction
+    // so it survives app updates — the installer re-extracts the bundled
+    // runtime from scratch on each launch, but never touches the extras tree.
    // (CHAOSENGINE_EXTRAS_SITE_PACKAGES is already set by the caller so
    // the backend can target it for pip --target installs.)
    let extras_dir = chaosengine_extras_site_packages_for_python(
@@ -650,10 +653,13 @@
    )
    .filter(|path| path.is_dir());
    let mut python_path_entries: Vec<PathBuf> = Vec::with_capacity(runtime.python_path.len() + 1);
+    if let Some(first) = runtime.python_path.first() {
+        python_path_entries.push(first.clone());
+    }
    if let Some(extras) = extras_dir.as_ref() {
        python_path_entries.push(extras.clone());
    }
-    python_path_entries.extend(runtime.python_path.iter().cloned());
+    python_path_entries.extend(runtime.python_path.iter().skip(1).cloned());
    if let Some(python_path) = join_paths(&python_path_entries) {
        command.env("PYTHONPATH", python_path);
    }
diff --git a/src/App.tsx b/src/App.tsx
index 877691a..20c2555 100644
--- a/src/App.tsx
+++ b/src/App.tsx
@@ -1,4 +1,4 @@
-import { useEffect, useRef, useState } from "react";
+import { useEffect, useMemo, useRef, useState } from "react";
 import {
   checkBackend,
   convertModel,
@@ -67,6 +67,7 @@
   libraryItemBackend,
   libraryItemSourceKind,
   inferHfRepoFromLocalPath,
+  isChatLibraryItem,
   downloadProgressLabel,
   syncRuntime,
   settingsDraftFromWorkspace,
@@ -174,6 +175,7 @@
    systemPrompt, setSystemPrompt,
    serverModelKey, setServerModelKey,
    installingPackage,
+    installLogs,
    updateLaunchSetting,
    updateConversionDraft,
    handleAddDirectory,
@@ -230,17 +232,18 @@
  // ── Library state ──────────────────────────────────────────
  const [librarySearchInput, setLibrarySearchInput] = useState("");
-  const [selectedLibraryPath, setSelectedLibraryPath] = useState(workspace.library[0]?.path ?? "");
+  const [selectedLibraryPath, setSelectedLibraryPath] = useState(workspace.library.find(isChatLibraryItem)?.path ?? 
""); const [expandedLibraryPath, setExpandedLibraryPath] = useState(null); const [librarySortKey, setLibrarySortKey] = useState<"name" | "format" | "backend" | "size" | "ram" | "compressed" | "modified" | "context">("modified"); const [librarySortDir, setLibrarySortDir] = useState<"asc" | "desc">("desc"); const [libraryCapFilter, setLibraryCapFilter] = useState(null); const [libraryFormatFilter, setLibraryFormatFilter] = useState(null); const [libraryBackendFilter, setLibraryBackendFilter] = useState(null); + const chatLibrary = useMemo(() => workspace.library.filter(isChatLibraryItem), [workspace.library]); // Library search sync useEffect(() => { - const nextFilteredLibrary = workspace.library + const nextFilteredLibrary = chatLibrary .filter((item) => { const haystack = `${item.name} ${item.path} ${item.format} ${item.directoryLabel ?? ""}`.toLowerCase(); return haystack.includes(librarySearchInput.trim().toLowerCase()); @@ -253,10 +256,10 @@ export default function App() { setSelectedLibraryPath((current) => nextFilteredLibrary.some((item) => item.path === current) ? current : nextFilteredLibrary[0].path, ); - }, [workspace.library, librarySearchInput]); + }, [chatLibrary, librarySearchInput]); // Library rows - const libraryRows = workspace.library.map((item) => { + const libraryRows = chatLibrary.map((item) => { const matchedVariant = findCatalogVariantForLibraryItem(workspace.featuredModels, item); return { item, @@ -314,7 +317,6 @@ export default function App() { }; }); const filteredLibraryRows = [...libraryRows, ...syntheticDownloadRows] - .filter(({ item }) => item.modelType === "text" || (!item.modelType)) .filter(({ item, displayFormat, displayQuantization, displayBackend, sourceKind }) => { const haystack = `${item.name} ${item.path} ${displayFormat} ${displayQuantization ?? ""} ${displayBackend} ${sourceKind} ${item.directoryLabel ?? ""}`.toLowerCase(); return haystack.includes(librarySearchInput.trim().toLowerCase()); @@ -337,28 +339,11 @@ export default function App() { const selectedLibraryVariant = selectedLibraryRow?.matchedVariant ?? null; // ── Chat model options ───────────────────────────────────── - const catalogChatOptions: ChatModelOption[] = allFeaturedVariants - .filter((variant) => variant.launchMode === "direct") - .map((variant) => ({ - key: `catalog:${variant.id}`, - label: variant.name, - detail: `${variant.format} / ${variant.quantization}`, - group: "Catalog", - model: variant.name, - modelRef: variant.id, - canonicalRepo: variant.repo, - source: "catalog", - backend: variant.backend, - paramsB: variant.paramsB, - sizeGb: variant.sizeGb, - contextWindow: variant.contextWindow, - format: variant.format, - quantization: variant.quantization, - maxContext: variant.maxContext ?? null, - })); - - const libraryChatOptions: ChatModelOption[] = workspace.library - .filter((item) => (item.modelType === "text" || (!item.modelType)) && !item.broken) + // Only list models present in the local library — catalog-only entries + // would let the user pick a model that isn't downloaded yet, which then + // 500s on Load. Discover tab is the place to pull a new model. 
+ const libraryChatOptions: ChatModelOption[] = chatLibrary + .filter((item) => !item.broken) .map((item) => { const matched = findCatalogVariantForLibraryItem(workspace.featuredModels, item); const displayFormat = libraryItemFormat(item, matched); @@ -383,7 +368,7 @@ export default function App() { }; }); - const threadModelOptions = [...catalogChatOptions, ...libraryChatOptions]; + const threadModelOptions = libraryChatOptions; // ── Cache labels (needed early by useChat) ────────────────── const currentCacheLabel = launchSettings.cacheStrategy === "native" @@ -639,7 +624,7 @@ export default function App() { const previewSavings = Math.max(0, preview.baselineCacheGb - preview.optimizedCacheGb); const conversionReady = Boolean(nativeBackends?.converterAvailable ?? workspace.system.mlxLmAvailable); const enabledDirectoryCount = (workspace.settings?.modelDirectories ?? []).filter((directory) => directory.enabled).length; - const libraryTotalSizeGb = workspace.library.reduce((sum, item) => sum + item.sizeGb, 0); + const libraryTotalSizeGb = chatLibrary.reduce((sum, item) => sum + item.sizeGb, 0); const localVariantCount = allFeaturedVariants.filter((variant) => variant.availableLocally).length; const fileRevealLabel = workspace.system.platform === "Darwin" ? "Show in Finder" : @@ -664,11 +649,11 @@ export default function App() { if (!selectedServerOptionBase || selectedServerOptionBase.source !== "catalog") return selectedServerOptionBase; const variant = findVariantForReference(workspace.featuredModels, selectedServerOptionBase.modelRef, selectedServerOptionBase.model); if (!variant) return selectedServerOptionBase; - const localItem = findLibraryItemForVariant(workspace.library, variant); + const localItem = findLibraryItemForVariant(chatLibrary, variant); if (!localItem) return selectedServerOptionBase; return libraryChatOptions.find((option) => option.path === localItem.path) ?? selectedServerOptionBase; })(); - const convertibleLibrary = workspace.library.filter((item) => libraryItemFormat(item) !== "MLX"); + const convertibleLibrary = chatLibrary.filter((item) => libraryItemFormat(item) !== "MLX"); const conversionSource = convertibleLibrary.find((item) => item.path === conversionDraft.path) ?? null; const conversionVariant = (conversionSource ? findCatalogVariantForLibraryItem(workspace.featuredModels, conversionSource) : null) ?? @@ -777,12 +762,12 @@ export default function App() { if (!threadModelOptions.length) { setBenchmarkModelKey(""); return; } setBenchmarkModelKey((current) => { if (threadModelOptions.some((option) => option.key === current)) return current; - const firstHealthy = workspace.library.find((item) => !item.broken); + const firstHealthy = chatLibrary.find((item) => !item.broken); if (firstHealthy) return `library:${firstHealthy.path}`; - if (workspace.library.length > 0) return `library:${workspace.library[0].path}`; + if (chatLibrary.length > 0) return `library:${chatLibrary[0].path}`; return activeThreadOption?.key ?? loadedModelOption?.key ?? 
threadModelOptions[0].key; }); - }, [activeThreadOption?.key, loadedModelOption?.key, serverOptionKeySignature, workspace.library, setBenchmarkModelKey]); + }, [activeThreadOption?.key, chatLibrary, loadedModelOption?.key, serverOptionKeySignature, setBenchmarkModelKey]); // Sync benchmarkDraft model fields useEffect(() => { @@ -1022,7 +1007,7 @@ export default function App() { } function loadPayloadFromVariant(variant: ModelVariant, nextTab?: TabId) { - const localItem = findLibraryItemForVariant(workspace.library, variant); + const localItem = findLibraryItemForVariant(chatLibrary, variant); if (localItem) { return { modelRef: localItem.name, @@ -1050,7 +1035,7 @@ export default function App() { | "cacheStrategy" | "cacheBits" | "fp16Layers" | "fusedAttention" | "fitModelInMemory" | "contextTokens" | "speculativeDecoding" | "dflashDraftModel" | "treeBudget" > { - const localItem = findLibraryItemForVariant(workspace.library, variant); + const localItem = findLibraryItemForVariant(chatLibrary, variant); const modelRef = localItem?.name ?? variant.id; const modelName = localItem?.name ?? variant.name; const modelBackend = localItem ? libraryItemBackend(localItem, variant) : variant.backend; @@ -1124,7 +1109,7 @@ export default function App() { if (normalizedKey?.startsWith("catalog:")) { const modelRef = normalizedKey.slice("catalog:".length); const variant = findVariantForReference(workspace.featuredModels, modelRef, undefined); - const localItem = variant ? findLibraryItemForVariant(workspace.library, variant) : null; + const localItem = variant ? findLibraryItemForVariant(chatLibrary, variant) : null; if (localItem) normalizedKey = `library:${localItem.path}`; } // If no key given, or the key references a model no longer in the options @@ -1276,7 +1261,7 @@ export default function App() { expandedVariantId={expandedVariantId} onExpandedVariantIdChange={setExpandedVariantId} onDetailFamilyIdChange={setDetailFamilyId} - library={workspace.library} + library={chatLibrary} activeDownloads={activeDownloads} onDownloadModel={(repo) => void handleDownloadModel(repo)} onCancelModelDownload={(repo) => void handleCancelModelDownload(repo)} @@ -1491,7 +1476,7 @@ export default function App() { longLiveJob={videoState.longLiveJob} onActiveTabChange={setActiveTab} onOpenVideoStudio={videoState.openVideoStudio} - onVideoDownload={(repo) => void videoState.handleVideoDownload(repo)} + onVideoDownload={(repo, modelId) => void videoState.handleVideoDownload(repo, modelId)} onCancelVideoDownload={(repo) => void videoState.handleCancelVideoDownload(repo)} onDeleteVideoDownload={(repo) => void videoState.handleDeleteVideoDownload(repo)} onOpenExternalUrl={(url) => void handleOpenExternalUrl(url)} @@ -1513,7 +1498,7 @@ export default function App() { fileRevealLabel={fileRevealLabel} onActiveTabChange={setActiveTab} onOpenVideoStudio={videoState.openVideoStudio} - onVideoDownload={(repo) => void videoState.handleVideoDownload(repo)} + onVideoDownload={(repo, modelId) => void videoState.handleVideoDownload(repo, modelId)} onCancelVideoDownload={(repo) => void videoState.handleCancelVideoDownload(repo)} onDeleteVideoDownload={(repo) => void videoState.handleDeleteVideoDownload(repo)} onPreloadVideoModel={(variant) => void videoState.handlePreloadVideoModel(variant)} @@ -1573,7 +1558,7 @@ export default function App() { onActiveTabChange={setActiveTab} onPreloadVideoModel={(variant) => void videoState.handlePreloadVideoModel(variant)} onUnloadVideoModel={(variant) => void 
videoState.handleUnloadVideoModel(variant)} - onVideoDownload={(repo) => void videoState.handleVideoDownload(repo)} + onVideoDownload={(repo, modelId) => void videoState.handleVideoDownload(repo, modelId)} onGenerateVideo={() => void videoState.handleVideoGenerate()} onOpenExternalUrl={(url) => void handleOpenExternalUrl(url)} onRestartServer={() => void handleRestartServer()} @@ -1613,7 +1598,7 @@ export default function App() { convertibleLibrary={convertibleLibrary} nativeBackends={nativeBackends} preview={preview} - workspace={{ system: workspace.system, library: workspace.library }} + workspace={{ system: workspace.system, library: chatLibrary }} launchCacheLabel={launchCacheLabel} busy={busy} busyAction={busyAction} @@ -1730,7 +1715,7 @@ export default function App() { ) : null; @@ -1982,6 +1969,7 @@ export default function App() { availableCacheStrategies={workspace.system.availableCacheStrategies} dflashInfo={workspace.system.dflash} installingPackage={installingPackage} + installLogs={installLogs} turboInstalled={Boolean(workspace.system.llamaServerTurboPath)} onPendingLaunchChange={setPendingLaunch} onLaunchModelSearchChange={setLaunchModelSearch} @@ -2056,7 +2044,7 @@ export default function App() {
Variants ({family.variants.length}) {family.variants.map((variant) => { - const matchedLocal = findLibraryItemForVariant(workspace.library, variant); + const matchedLocal = findLibraryItemForVariant(chatLibrary, variant); const downloadState = activeDownloads[variant.repo]; const isDownloading = downloadState?.state === "downloading"; const isDownloadPaused = downloadState?.state === "cancelled"; diff --git a/src/api.ts b/src/api.ts index 1741f99..1881b06 100644 --- a/src/api.ts +++ b/src/api.ts @@ -678,8 +678,8 @@ export async function unloadImageModel(modelId?: string): Promise { -export async function downloadVideoModel(repo: string): Promise<DownloadStatus> { - const result = await postJson<{ download: DownloadStatus }>("/api/video/download", { repo }); +export async function downloadVideoModel(repo: string, modelId?: string): Promise<DownloadStatus> { + const result = await postJson<{ download: DownloadStatus }>("/api/video/download", { repo, modelId }); return result.download; } diff --git a/src/components/LatestImageDiscoverCard.tsx b/src/components/LatestImageDiscoverCard.tsx deleted file mode 100644 index 982e2c3..0000000 --- a/src/components/LatestImageDiscoverCard.tsx +++ /dev/null @@ -1,172 +0,0 @@ -import type { ImageModelVariant } from "../types"; -import type { DownloadStatus } from "../api"; -import { - imagePrimarySizeLabel, - imageSecondarySizeLabel, - formatImageLicenseLabel, - formatImageAccessError, - formatReleaseLabel, - isGatedImageAccessError, -} from "../utils/format"; -import { downloadProgressLabel, downloadSizeTooltip } from "../utils/downloads"; - -export interface LatestImageDiscoverCardProps { - variant: ImageModelVariant; - downloadState?: DownloadStatus; - fileRevealLabel?: string; - onDownload: (repo: string) => void; - onCancelDownload: (repo: string) => void; - onDeleteDownload: (repo: string) => void; - onOpenExternalUrl: (url: string) => void; - onNavigateSettings: () => void; - onRevealPath?: (path: string) => void; -} - -export function LatestImageDiscoverCard({ - variant, - downloadState, - fileRevealLabel, - onDownload, - onCancelDownload, - onDeleteDownload, - onOpenExternalUrl, - onNavigateSettings, - onRevealPath, -}: LatestImageDiscoverCardProps) { - const isDownloadPaused = downloadState?.state === "cancelled"; - const isDownloadComplete = downloadState?.state === "completed"; - const isDownloadFailed = downloadState?.state === "failed"; - const hasLocalData = Boolean(variant.hasLocalData || isDownloadComplete || isDownloadPaused || isDownloadFailed); - const friendlyDownloadError = formatImageAccessError(downloadState?.error, variant); - const needsGatedAccess = isGatedImageAccessError(downloadState?.error); - return ( -
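The widened download endpoint above threads an optional modelId through to the backend so multi-variant repos can route to a specific catalog entry. A minimal calling sketch, assuming only the DownloadStatus states that appear elsewhere in this diff (repo and modelId values are hypothetical):

async function startVideoDownload(): Promise<void> {
  // modelId is optional; omitting it preserves the old repo-only behaviour.
  const status = await downloadVideoModel("SomeOrg/some-video-model", "some-variant-id");
  // DownloadStatus.state values observed in this diff:
  // "downloading" | "cancelled" (paused) | "completed" | "failed"
  if (status.state === "failed") {
    console.error(status.error);
  }
}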
-
-
-
-

{variant.name}

- {variant.provider} - {!variant.availableLocally && isDownloadComplete ? Downloaded : null} - {isDownloadPaused ? Paused : null} - {isDownloadFailed ? Download Failed : null} -
-

{variant.note}

-
- {variant.updatedLabel ?? "Recently updated"} -
- -
- {imagePrimarySizeLabel(variant)} - {imageSecondarySizeLabel(variant) ? {imageSecondarySizeLabel(variant)} : null} - {variant.recommendedResolution} - {variant.pipelineTag ? {variant.pipelineTag} : null} -
- -
- {formatReleaseLabel(variant.releaseLabel, variant.releaseDate ?? variant.createdAt) ? ( - {formatReleaseLabel(variant.releaseLabel, variant.releaseDate ?? variant.createdAt)} - ) : null} - {variant.downloadsLabel ? {variant.downloadsLabel} : null} - {variant.likesLabel ? {variant.likesLabel} : null} - {variant.license ? {formatImageLicenseLabel(variant.license)} : null} - {typeof variant.gated === "boolean" ? {variant.gated ? "Gated access" : "Open access"} : null} -
- -
- {variant.taskSupport.map((task) => ( - {task} - ))} - {variant.styleTags.map((tag) => ( - {tag} - ))} -
- - {isDownloadFailed && downloadState?.error ? ( -
-

{friendlyDownloadError}

- {needsGatedAccess ? ( -
- - -
- ) : null} - {friendlyDownloadError !== downloadState.error ? ( -
- Technical details -

{downloadState.error}

-
- ) : null} -
- ) : null} - -
- {variant.availableLocally ? ( - Installed - ) : downloadState?.state === "downloading" ? ( - <> - {downloadProgressLabel(downloadState)} - - - - ) : isDownloadPaused ? ( - <> - {downloadProgressLabel(downloadState)} - - - - ) : isDownloadFailed ? ( - <> - - - - ) : isDownloadComplete ? ( - Download complete - ) : ( - <> - - {hasLocalData ? ( - - ) : null} - - )} - {variant.localPath && onRevealPath ? ( - - ) : null} - -
-
- ); -} diff --git a/src/components/LaunchModal.tsx b/src/components/LaunchModal.tsx index b89d5ab..ba0d7a5 100644 --- a/src/components/LaunchModal.tsx +++ b/src/components/LaunchModal.tsx @@ -1,5 +1,5 @@ import { ModelLaunchModal } from "./ModelLaunchModal"; -import type { LaunchPreferences, PreviewMetrics, SystemStats } from "../types"; +import type { LaunchPreferences, PreviewMetrics, StrategyInstallLog, SystemStats } from "../types"; import type { ChatModelOption } from "../types/chat"; export interface PendingLaunch { @@ -19,6 +19,7 @@ export interface LaunchModalProps { availableCacheStrategies: SystemStats["availableCacheStrategies"] | undefined; dflashInfo?: SystemStats["dflash"]; installingPackage: string | null; + installLogs?: Record; turboInstalled?: boolean; onPendingLaunchChange: (value: PendingLaunch | null | ((prev: PendingLaunch | null) => PendingLaunch | null)) => void; onLaunchModelSearchChange: (value: string) => void; @@ -39,6 +40,7 @@ export function LaunchModal({ availableCacheStrategies, dflashInfo, installingPackage, + installLogs, turboInstalled, onPendingLaunchChange, onLaunchModelSearchChange, @@ -76,6 +78,7 @@ export function LaunchModal({ availableCacheStrategies={availableCacheStrategies} dflashInfo={dflashInfo} installingPackage={installingPackage} + installLogs={installLogs} turboInstalled={turboInstalled} onSelectedKeyChange={setSelectedLaunchKey} onSearchChange={onLaunchModelSearchChange} diff --git a/src/components/ModelActionIcons.tsx b/src/components/ModelActionIcons.tsx new file mode 100644 index 0000000..7a2329e --- /dev/null +++ b/src/components/ModelActionIcons.tsx @@ -0,0 +1,256 @@ +import type { ButtonHTMLAttributes, ReactNode } from "react"; + +export type ActionIconName = + | "cancel" + | "chat" + | "convert" + | "delete" + | "download" + | "generate" + | "huggingFace" + | "install" + | "modelCard" + | "pause" + | "resume" + | "retry" + | "reveal" + | "server"; + +export type ModelStatusKind = + | "downloaded" + | "downloading" + | "failed" + | "incomplete" + | "installed" + | "loaded" + | "paused"; + +type IconProps = { + name: ActionIconName | ModelStatusKind; + className?: string; +}; + +function Svg({ children, className }: { children: ReactNode; className?: string }) { + return ( + + ); +} + +export function ModelActionIcon({ name, className }: IconProps) { + switch (name) { + case "cancel": + case "failed": + return ( + + + + + + ); + case "chat": + return ( + + + + + + ); + case "convert": + return ( + + + + + + + ); + case "delete": + return ( + + + + + + + + ); + case "download": + case "downloading": + return ( + + + + + + ); + case "generate": + return ( + + + + + + + ); + case "huggingFace": + return ( + + + + + + + + + + + + ); + case "install": + return ( + + + + + + ); + case "modelCard": + return ( + + + + + + + + + ); + case "pause": + case "paused": + return ( + + + + + + ); + case "resume": + return ( + + + + ); + case "retry": + return ( + + + + + ); + case "reveal": + return ( + + + + + + ); + case "server": + return ( + + + + + + + + + ); + case "downloaded": + case "installed": + return ( + + + + + ); + case "loaded": + return ( + + + + + ); + case "incomplete": + default: + return ( + + + + + ); + } +} + +type IconActionButtonProps = ButtonHTMLAttributes & { + icon: ActionIconName; + label: string; + buttonStyle?: "primary" | "secondary"; + danger?: boolean; +}; + +export function IconActionButton({ + icon, + label, + buttonStyle = "secondary", + danger = false, + className = "", + title, + type = "button", + ...props 
+}: IconActionButtonProps) { + return ( + + ); +} + +export function StatusIcon({ + status, + label, + detail, +}: { + status: ModelStatusKind; + label: string; + detail?: string | null; +}) { + const title = detail ? `${label}: ${detail}` : label; + return ( + + + {title} + + ); +} diff --git a/src/components/ModelLaunchModal.tsx b/src/components/ModelLaunchModal.tsx index db8ff73..432ce6c 100644 --- a/src/components/ModelLaunchModal.tsx +++ b/src/components/ModelLaunchModal.tsx @@ -1,7 +1,7 @@ import { useEffect, useState } from "react"; import { RuntimeControls } from "./RuntimeControls"; import { number, sizeLabel } from "../utils"; -import type { LaunchPreferences, PreviewMetrics, SystemStats } from "../types"; +import type { LaunchPreferences, PreviewMetrics, StrategyInstallLog, SystemStats } from "../types"; import type { ChatModelOption } from "../types/chat"; export interface ModelLaunchModalProps { @@ -19,6 +19,7 @@ export interface ModelLaunchModalProps { availableCacheStrategies: SystemStats["availableCacheStrategies"] | undefined; dflashInfo?: SystemStats["dflash"]; installingPackage: string | null; + installLogs?: Record; turboInstalled?: boolean; onSelectedKeyChange: (key: string) => void; onSearchChange: (value: string) => void; @@ -43,6 +44,7 @@ export function ModelLaunchModal({ availableCacheStrategies, dflashInfo, installingPackage, + installLogs, turboInstalled, onSelectedKeyChange, onSearchChange, @@ -161,6 +163,7 @@ export function ModelLaunchModal({ availableCacheStrategies={availableCacheStrategies} onInstallPackage={onInstallPackage} installingPackage={installingPackage} + installLogs={installLogs} dflashInfo={dflashInfo} selectedBackend={selectedOption?.backend} selectedModelRef={selectedOption?.modelRef} diff --git a/src/components/PerformancePreview.tsx b/src/components/PerformancePreview.tsx index e7ee592..80e51a8 100644 --- a/src/components/PerformancePreview.tsx +++ b/src/components/PerformancePreview.tsx @@ -1,5 +1,6 @@ import type { PreviewMetrics } from "../types"; import { ProgressRow } from "./ProgressRow"; +import { getCacheFitStatus } from "../utils/cache"; interface PerformancePreviewProps { preview: PreviewMetrics; @@ -13,48 +14,6 @@ function fmt(value: number, digits = 1): string { return value.toFixed(digits); } -interface FitStatus { - label: string; - className: string; - /** Human-readable explanation of the dominant lever when things don't - * fit. Only populated for the "May not fit" tier — the other tiers are - * self-explanatory. */ - advice: string | null; -} - -function getFitStatus( - optimizedCacheGb: number, - diskSizeGb: number, - totalGb: number, - bits: number, -): FitStatus { - // Use total system memory since loading a new model unloads the previous one. - const totalNeeded = optimizedCacheGb + diskSizeGb; - // Reserve ~20% for OS and other apps - const usable = totalGb * 0.80; - const ratio = usable > 0 ? totalNeeded / usable : 1; - if (ratio < 0.7) return { label: "Fits easily", className: "success", advice: null }; - if (ratio < 0.95) return { label: "Tight fit", className: "warning", advice: null }; - - // "May not fit" — pick the most useful lever to show the user. When the - // cache pool dwarfs the weights (classic "256K context on a 26B model" - // situation), the right fix is context + strategy, not model size. When - // the weights themselves are the problem, no context lever will help. 
- const cacheDominates = optimizedCacheGb > diskSizeGb * 1.5; - let advice: string; - if (!cacheDominates) { - advice = - "Model weights alone exceed available RAM. Pick a smaller model or a more aggressive quantisation."; - } else if (bits <= 0) { - advice = - "Native f16 cache grows with context — at this setting it's bigger than RAM. Lower the context slider, or pick a compressed strategy (RotorQuant / TriAttention)."; - } else { - advice = - "Compressed cache still exceeds RAM at this context. Lower the context slider or reduce FP16 layers."; - } - return { label: "May not fit", className: "warning", advice }; -} - function getSpeedLabel(tokS: number): { label: string; className: string } | null { if (tokS < 5) return { label: "Slow", className: "perf-preview__speed-label--slow" }; if (tokS < 15) return { label: "Good", className: "perf-preview__speed-label--good" }; @@ -64,7 +23,7 @@ function getSpeedLabel(tokS: number): { label: string; className: string } | nul export function PerformancePreview({ preview, availableMemoryGb, totalMemoryGb, compact, actualDiskSizeGb }: PerformancePreviewProps) { const diskGb = actualDiskSizeGb ?? preview.diskSizeGb; - const fitStatus = getFitStatus(preview.optimizedCacheGb, diskGb, totalMemoryGb, preview.bits); + const fitStatus = getCacheFitStatus(preview.optimizedCacheGb, diskGb, totalMemoryGb, preview.bits); const cacheDelta = preview.baselineCacheGb - preview.optimizedCacheGb; const qualityDelta = preview.qualityPercent - 100; const cacheMax = Math.max(preview.baselineCacheGb, totalMemoryGb * 0.6, 1); diff --git a/src/components/RuntimeControls.tsx b/src/components/RuntimeControls.tsx index 1d0b98c..9480fcb 100644 --- a/src/components/RuntimeControls.tsx +++ b/src/components/RuntimeControls.tsx @@ -1,5 +1,5 @@ import { useEffect, useState } from "react"; -import type { LaunchPreferences, PreviewMetrics } from "../types"; +import type { LaunchPreferences, PreviewMetrics, StrategyInstallLog } from "../types"; import { SliderField } from "./SliderField"; import { PerformancePreview } from "./PerformancePreview"; import { @@ -121,6 +121,7 @@ interface RuntimeControlsProps { availableCacheStrategies?: CacheStrategyOption[]; onInstallPackage?: (strategyId: string) => void; installingPackage?: string | null; + installLogs?: Record; dflashInfo?: DFlashInfo; /** Backend of the selected model (e.g. "mlx", "gguf", "vllm", "auto"). Used for compatibility validation. */ selectedBackend?: string | null; @@ -133,6 +134,77 @@ interface RuntimeControlsProps { turboUpdateAvailable?: boolean; } +function StrategyInstallTerminal({ + label, + log, +}: { + label: string; + log?: StrategyInstallLog; +}) { + const status = log?.status ?? "idle"; + const summaryStatus = + status === "running" ? "running" : + status === "success" ? "complete" : + status === "failed" ? "failed" : + "ready"; + const lines = log?.steps.length + ? log.steps.map((step) => [ + `$ ${step.command}`, + `[${step.status.toUpperCase()}] ${step.label}`, + formatStrategyInstallOutput(step.output, step.status), + ].join("\n")).join("\n\n") + : "No install output yet. Run the installer to capture stdout and stderr here."; + + return ( +
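A worked example of the extracted fit heuristic, assuming getCacheFitStatus keeps the signature PerformancePreview calls it with (optimizedCacheGb, diskSizeGb, totalGb, bits):

import { getCacheFitStatus } from "../utils/cache";

// 64 GB machine: usable = 64 * 0.80 = 51.2 GB.
// 26 GB of weights + 30 GB of optimized cache gives 56 / 51.2 ≈ 1.09 ≥ 0.95,
// so the tier is "May not fit". The cache (30 GB) does not exceed 1.5x the
// weights (39 GB), so the advice targets model size rather than the context slider.
const fit = getCacheFitStatus(30, 26, 64, 4);
console.log(fit.label);  // "May not fit"
console.log(fit.advice); // the weights-focused advice string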
+ + {label} install terminal + + {summaryStatus} + + +
+ {log?.startedAt ? Started {log.startedAt} : Collapsed by default. Open after an install attempt to inspect failures.} + {log?.finishedAt ? Finished {log.finishedAt} : null} +
+
{lines}
+
+ ); +} + +function formatStrategyInstallOutput(output: string, status: string): string { + const trimmed = output.trim(); + if (!trimmed || status !== "success") return trimmed || "(no output)"; + + const lines = trimmed.split(/\r?\n/); + const filtered: string[] = []; + let omittedResolverWarning = false; + let inResolverWarning = false; + for (const line of lines) { + const text = line.trim(); + if (/^ERROR: pip's dependency resolver does not currently take into account/i.test(text)) { + omittedResolverWarning = true; + inResolverWarning = true; + continue; + } + if (inResolverWarning) { + if ( + text === "" || + /^\S+\s+\S+\s+requires\s+.+which is not installed\.$/i.test(text) || + /^\S+\s+\S+\s+requires\s+.+but you have .+ which is incompatible\.$/i.test(text) + ) { + continue; + } + inResolverWarning = false; + } + filtered.push(line); + } + if (omittedResolverWarning) { + filtered.push("[pip resolver warnings omitted; install command exited successfully]"); + } + return filtered.join("\n").trim() || "(no output)"; +} + export function RuntimeControls({ settings, onChange, @@ -147,6 +219,7 @@ export function RuntimeControls({ availableCacheStrategies, onInstallPackage, installingPackage, + installLogs, dflashInfo, selectedBackend, selectedModelRef, @@ -173,6 +246,8 @@ export function RuntimeControls({ const dflashUnavailableReason = dflashSupport.reason; const ddtreeAvailable = dflashSupport.ddtreeAvailable; const canInstallDflashForModel = dflashSupport.modelSupported === true; + const dflashInstallLog = installLogs?.["dflash-mlx"] ?? installLogs?.dflash; + const showDflashInstallTerminal = Boolean(dflashInstallLog || (!dflashInstalled && !isGgufBackend && canInstallDflashForModel && onInstallPackage)); const specActive = settings.speculativeDecoding && dflashAvailable; const strategies = (availableCacheStrategies ?? [{id: "native", name: "Native f16", available: true, bitRange: null, defaultBits: null, supportsFp16Layers: false}]) .filter((s) => !s.appliesTo || s.appliesTo.length === 0 || s.appliesTo.includes("text")); @@ -180,6 +255,13 @@ export function RuntimeControls({ const selectedStrategy = strategies.find(s => s.id === settings.cacheStrategy) ?? strategies[0]; const fp16LayersSupported = Boolean(selectedStrategy?.supportsFp16Layers) && !isGgufBackend; const [expandedInfo, setExpandedInfo] = useState(null); + const isStrategyRuntimeAvailable = (strategy: CacheStrategyOption) => { + if (strategy.requiredLlamaBinary === "turbo" && isGgufBackend) { + return Boolean(turboInstalled); + } + return strategy.available; + }; + const selectedStrategyRuntimeAvailable = selectedStrategy ? 
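For reference, an input/output sketch of the resolver-warning filter above (package names and pip wording are illustrative):

const raw = [
  "Successfully installed rotorquant-0.3.1",
  "ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.",
  "somepkg 1.0 requires oldlib<2.0, but you have oldlib 3.0 which is incompatible.",
].join("\n");

// Success: the resolver block is dropped and a one-line note is appended.
formatStrategyInstallOutput(raw, "success");
// → "Successfully installed rotorquant-0.3.1\n[pip resolver warnings omitted; install command exited successfully]"

// Any other status returns the trimmed output untouched.
formatStrategyInstallOutput(raw, "failed"); // → raw, trimmed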
isStrategyRuntimeAvailable(selectedStrategy) : false; useEffect(() => { if (isGgufBackend && settings.fp16Layers !== 0) { @@ -189,7 +271,7 @@ export function RuntimeControls({ useEffect(() => { if (settings.cacheStrategy === "native") return; - if (hasSelectedStrategy && selectedStrategy?.available && isStrategyCompatible(settings.cacheStrategy, selectedBackend)) return; + if (hasSelectedStrategy && selectedStrategyRuntimeAvailable && isStrategyCompatible(settings.cacheStrategy, selectedBackend)) return; onChange("cacheStrategy", "native"); if (settings.cacheBits !== 0) onChange("cacheBits", 0); if (settings.fp16Layers !== 0) onChange("fp16Layers", 0); @@ -197,7 +279,7 @@ export function RuntimeControls({ hasSelectedStrategy, onChange, selectedBackend, - selectedStrategy?.available, + selectedStrategyRuntimeAvailable, settings.cacheBits, settings.cacheStrategy, settings.fp16Layers, @@ -253,7 +335,7 @@ export function RuntimeControls({ } function selectStrategy(strategy: CacheStrategyOption) { - if (!strategy.available || !isStrategyCompatible(strategy.id, selectedBackend)) return; + if (!isStrategyRuntimeAvailable(strategy) || !isStrategyCompatible(strategy.id, selectedBackend)) return; onChange("cacheStrategy", strategy.id); if (strategy.defaultBits != null) { onChange("cacheBits", strategy.defaultBits); @@ -285,8 +367,9 @@ export function RuntimeControls({ const incompatReason = strategyIncompatReason(strategy.id, selectedBackend); const isIncompat = incompatReason != null; const needsTurbo = strategy.requiredLlamaBinary === "turbo"; - const turboMissing = needsTurbo && isGgufBackend && turboInstalled === false; - const isDisabled = !strategy.available || (specActive && strategy.id !== "native") || isIncompat || turboMissing; + const turboMissing = needsTurbo && isGgufBackend && !turboInstalled; + const runtimeAvailable = isStrategyRuntimeAvailable(strategy); + const isDisabled = !runtimeAvailable || (specActive && strategy.id !== "native") || isIncompat || turboMissing; return (
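The availability override above keeps catalog flags authoritative except for turbo-gated strategies on GGUF backends, where the installed turbo binary becomes the source of truth. A condensed restatement (strategy fields abbreviated, the id is hypothetical):

// requiredLlamaBinary === "turbo" on a GGUF backend → track turboInstalled;
// every other case falls through to the catalog's own `available` flag.
const strategy = { id: "rotorquant", available: false, requiredLlamaBinary: "turbo" };
// isGgufBackend && turboInstalled  → true  ("Ready")
// isGgufBackend && !turboInstalled → false ("No turbo binary" badge)
// non-GGUF backend                 → strategy.available (false → "Install")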
@@ -301,10 +384,10 @@ export function RuntimeControls({ {strategy.name} - {isIncompat ? "N/A" : turboMissing ? "No turbo binary" : strategy.available ? "Ready" : strategy.availabilityBadge ?? "Install"} + {isIncompat ? "N/A" : turboMissing ? "No turbo binary" : runtimeAvailable ? "Ready" : strategy.availabilityBadge ?? "Install"} {info ? ( @@ -321,7 +404,7 @@ export function RuntimeControls({ {isExpanded && info ? (

{info.description}

- {!strategy.available && strategy.availabilityReason ? ( + {!runtimeAvailable && strategy.availabilityReason ? (

{strategy.availabilityReason}

) : null}
@@ -332,7 +415,7 @@ export function RuntimeControls({
Install: {info.install} - {info.autoInstallPackage && onInstallPackage && !strategy.available ? ( + {info.autoInstallPackage && onInstallPackage && !runtimeAvailable ? (
) : null}
@@ -524,6 +613,9 @@ export function RuntimeControls({ ) : null}
) : null} + {showDflashInstallTerminal ? ( + + ) : null} {settings.speculativeDecoding && dflashAvailable ? (
) : null} @@ -408,6 +407,17 @@ export function OnlineModelsTab({ const isDownloadPaused = downloadState?.state === "cancelled"; const isDownloadFailed = downloadState?.state === "failed"; const isDownloadComplete = downloadState?.state === "completed"; + const hubStatus: { kind: ModelStatusKind; label: string; detail?: string | null } | null = model.availableLocally + ? { kind: "installed", label: "Installed" } + : isDownloadComplete + ? { kind: "downloaded", label: "Download complete" } + : isDownloading && downloadState + ? { kind: "downloading", label: "Downloading", detail: downloadProgressLabel(downloadState) } + : isDownloadPaused && downloadState + ? { kind: "paused", label: "Paused", detail: downloadProgressLabel(downloadState) } + : isDownloadFailed && downloadState + ? { kind: "failed", label: "Failed", detail: downloadState.error ?? "Download failed" } + : null; return (
{model.name} {model.provider} {model.format} - {model.availableLocally ? Downloaded : null} - {!model.availableLocally && isDownloadComplete ? Download complete : null} - {!model.availableLocally && isDownloading ? ( - {downloadProgressLabel(downloadState)} - ) : null} - {!model.availableLocally && isDownloadPaused ? ( - {downloadProgressLabel(downloadState)} - ) : null} - {!model.availableLocally && isDownloadFailed ? ( - Download failed - ) : null} + {hubStatus ? <StatusIcon status={hubStatus.kind} label={hubStatus.label} detail={hubStatus.detail} /> : null}
{formatReleaseLabel(model.releaseLabel, model.createdAt) ? ( @@ -554,66 +554,30 @@ export function OnlineModelsTab({
{model.availableLocally ? ( <> - - + onOpenModelSelector("thread")} /> + onOpenModelSelector("server")} /> ) : isDownloading ? ( <> - {downloadProgressLabel(downloadState)} - - + onCancelModelDownload(model.repo)} /> + onDeleteModelDownload(model.repo)} /> ) : isDownloadPaused ? ( <> - {downloadProgressLabel(downloadState)} - - + onDownloadModel(model.repo)} /> + onDeleteModelDownload(model.repo)} /> ) : isDownloadFailed ? ( <> - Download failed - - + onDownloadModel(model.repo)} /> + onDeleteModelDownload(model.repo)} /> ) : isDownloadComplete ? ( - Download complete + ) : ( - + onDownloadModel(model.repo)} /> )} - { - event.preventDefault(); - onOpenExternalUrl(model.link); - }} - > - Open on HuggingFace ↗ - + onOpenExternalUrl(model.link)} />
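The old text buttons collapse into the shared icon components; a small usage sketch of the contract declared in ModelActionIcons.tsx (the progress detail string is illustrative):

import { IconActionButton, StatusIcon } from "../../components/ModelActionIcons";

function ExampleActions({ repo, onDownload }: { repo: string; onDownload: (repo: string) => void }) {
  return (
    <>
      {/* Renders a button with the named SVG icon; `label` supplies its accessible text. */}
      <IconActionButton icon="download" label="Download" onClick={() => onDownload(repo)} />
      {/* Renders a status badge; the tooltip reads "label" or "label: detail". */}
      <StatusIcon status="downloading" label="Downloading" detail="3.1 GB / 7.4 GB" />
    </>
  );
}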
) : null} diff --git a/src/features/video/VideoDiscoverTab.tsx b/src/features/video/VideoDiscoverTab.tsx index 475bdb3..383b6aa 100644 --- a/src/features/video/VideoDiscoverTab.tsx +++ b/src/features/video/VideoDiscoverTab.tsx @@ -1,5 +1,6 @@ -import { useEffect } from "react"; +import { useEffect, useMemo, useState } from "react"; import { InstallLogPanel } from "../../components/InstallLogPanel"; +import { IconActionButton, StatusIcon } from "../../components/ModelActionIcons"; import { Panel } from "../../components/Panel"; import type { DownloadStatus, InstallResult, LongLiveJobState } from "../../api"; import type { @@ -10,21 +11,27 @@ import type { import type { DiscoverSort } from "../../types/image"; import type { VideoDiscoverTaskFilter } from "../../types/video"; import { + compactModelSizeLabel, + compactReleaseLabel, downloadProgressLabel, downloadSizeTooltip, formatReleaseLabel, - number, + videoDiscoverMemoryEstimate, + videoDeleteLabelForRepo, + videoDeleteRepoForVariant, + videoDownloadStatusForVariant, videoPrimarySizeLabel, videoSecondarySizeLabel, } from "../../utils"; +type MediaStatusFilter = "all" | "installed" | "not-installed" | "downloading" | "paused" | "failed" | "incomplete"; +type SortDir = "asc" | "desc"; + // LongLive ships via a dedicated Python installer (isolated venv + GitHub // clone + HF weights at Efficient-Large-Model/LongLive-1.3B), not via // snapshot_download. The catalog repo id ``NVlabs/LongLive-1.3B`` is the // GitHub org and intentionally does not resolve on Hugging Face — we use -// it purely as a routing key. Detect LongLive here so the Discover card -// can swap the Download button for an Install LongLive CTA that matches -// the Studio tab's existing install affordance. +// it purely as a routing key. function isLongLiveRepo(repo: string | undefined): boolean { return repo?.startsWith("NVlabs/LongLive") ?? false; } @@ -44,13 +51,10 @@ export interface VideoDiscoverTabProps { fileRevealLabel: string; longLiveStatus: VideoRuntimeStatus | null; installingLongLive: boolean; - // Live LongLive install job — same async-poll job as VideoStudioTab so - // either tab's "Install LongLive" button drives the same backend - // worker and renders the same per-phase terminal output. longLiveJob: LongLiveJobState | null; onActiveTabChange: (tab: TabId) => void; onOpenVideoStudio: (modelId?: string) => void; - onVideoDownload: (repo: string) => void; + onVideoDownload: (repo: string, modelId?: string) => void; onCancelVideoDownload: (repo: string) => void; onDeleteVideoDownload: (repo: string) => void; onOpenExternalUrl: (url: string) => void; @@ -59,6 +63,106 @@ export interface VideoDiscoverTabProps { onInstallLongLive: () => Promise; } +function videoDiscoverSortLabel(sort: DiscoverSort): string { + if (sort === "name") return "name"; + if (sort === "provider") return "provider"; + if (sort === "tasks") return "tasks"; + if (sort === "size") return "largest size first"; + if (sort === "ram") return "highest RAM/VRAM first"; + if (sort === "likes") return "most liked first"; + if (sort === "downloads") return "most downloads first"; + if (sort === "status") return "status"; + return "newest released first"; +} + +function sortIndicator(activeSort: DiscoverSort, sortDir: SortDir, key: DiscoverSort): string { + if (activeSort !== key) return ""; + return sortDir === "asc" ? " \u25B2" : " \u25BC"; +} + +function defaultSortDir(sort: DiscoverSort): SortDir { + return sort === "name" || sort === "provider" || sort === "tasks" ? 
"asc" : "desc"; +} + +function releaseSortKey(variant: VideoModelVariant): string { + return variant.releaseDate ?? variant.createdAt ?? variant.lastModified ?? ""; +} + +function sizeSortKey(variant: VideoModelVariant): number | null { + const candidates = [variant.onDiskGb, variant.coreWeightsGb, variant.repoSizeGb, variant.sizeGb]; + for (const value of candidates) { + if (typeof value === "number" && Number.isFinite(value) && value > 0) return value; + } + return null; +} + +function compareNullableNumberDesc(left: number | null, right: number | null): number { + const leftKnown = typeof left === "number" && Number.isFinite(left); + const rightKnown = typeof right === "number" && Number.isFinite(right); + if (leftKnown && rightKnown) return (right as number) - (left as number); + if (leftKnown) return -1; + if (rightKnown) return 1; + return 0; +} + +function compareNullableNumber(left: number | null, right: number | null, dir: SortDir): number { + const desc = compareNullableNumberDesc(left, right); + return dir === "desc" ? desc : -desc; +} + +function statusSortKey(status: MediaStatusFilter): number { + if (status === "installed") return 0; + if (status === "downloading") return 1; + if (status === "paused") return 2; + if (status === "failed") return 3; + if (status === "incomplete") return 4; + if (status === "not-installed") return 5; + return 6; +} + +function memoryParts(label: string | null | undefined): { primary: string; secondary: string | null } { + if (!label) return { primary: "pending", secondary: null }; + const [primary, secondary] = label.split(" @ "); + if (!secondary) return { primary, secondary: null }; + return { primary: `${primary} @`, secondary }; +} + +function videoVariantStatus( + variant: VideoModelVariant, + downloadState: DownloadStatus | undefined, + longLiveReady: boolean, + installingLongLive: boolean, +): MediaStatusFilter { + if (isLongLiveRepo(variant.repo)) { + if (longLiveReady) return "installed"; + if (installingLongLive) return "downloading"; + return "not-installed"; + } + if (variant.availableLocally || downloadState?.state === "completed") return "installed"; + if (downloadState?.state === "downloading") return "downloading"; + if (downloadState?.state === "cancelled") return "paused"; + if (downloadState?.state === "failed") return "failed"; + if (variant.hasLocalData) return "incomplete"; + return "not-installed"; +} + +function statusBadge(status: MediaStatusFilter, downloadState?: DownloadStatus, longLiveInstalling = false) { + const downloadDetail = downloadState + ? [downloadProgressLabel(downloadState), downloadSizeTooltip(downloadState)].filter(Boolean).join(" / ") + : null; + if (status === "installed") return ; + if (longLiveInstalling) return ; + if (status === "downloading" && downloadState) { + return ; + } + if (status === "paused" && downloadState) { + return ; + } + if (status === "failed") return ; + if (status === "incomplete") return ; + return ; +} + export function VideoDiscoverTab({ combinedVideoDiscoverResults, videoDiscoverSearchInput, @@ -85,22 +189,85 @@ export function VideoDiscoverTab({ onRefreshLongLiveStatus, onInstallLongLive, }: VideoDiscoverTabProps) { - // Probe LongLive install state whenever the results include a LongLive - // variant so the card can render "Installed" vs "Install LongLive" - // without the user having to open the Studio tab first. Mirrors the - // same effect in VideoStudioTab. 
const hasLongLiveVariant = combinedVideoDiscoverResults.some((variant) => isLongLiveRepo(variant.repo), ); useEffect(() => { if (hasLongLiveVariant) onRefreshLongLiveStatus(); }, [hasLongLiveVariant, onRefreshLongLiveStatus]); + + const [statusFilter, setStatusFilter] = useState("all"); + const [sortDir, setSortDir] = useState(defaultSortDir(videoDiscoverSort)); const longLiveReady = longLiveStatus?.realGenerationAvailable ?? false; + const filteredResults = useMemo( + () => + combinedVideoDiscoverResults + .map((variant) => { + const downloadState = videoDownloadStatusForVariant(activeVideoDownloads, variant); + const status = videoVariantStatus(variant, downloadState, longLiveReady, installingLongLive); + const memoryEstimate = videoDiscoverMemoryEstimate(variant); + return { variant, status, memoryEstimate }; + }) + .filter(({ status }) => statusFilter === "all" || status === statusFilter) + .sort((left, right) => { + if (videoDiscoverSort === "name") { + const diff = left.variant.name.localeCompare(right.variant.name); + return sortDir === "asc" ? diff : -diff; + } + if (videoDiscoverSort === "provider") { + const diff = left.variant.provider.localeCompare(right.variant.provider); + if (diff !== 0) return sortDir === "asc" ? diff : -diff; + } + if (videoDiscoverSort === "tasks") { + const diff = left.variant.taskSupport.join(" ").localeCompare(right.variant.taskSupport.join(" ")); + if (diff !== 0) return sortDir === "asc" ? diff : -diff; + } + if (videoDiscoverSort === "size") { + const diff = compareNullableNumber(sizeSortKey(left.variant), sizeSortKey(right.variant), sortDir); + if (diff !== 0) return diff; + } else if (videoDiscoverSort === "ram") { + const diff = compareNullableNumber(left.memoryEstimate?.estimatedPeakGb ?? null, right.memoryEstimate?.estimatedPeakGb ?? null, sortDir); + if (diff !== 0) return diff; + } else if (videoDiscoverSort === "status") { + const diff = statusSortKey(left.status) - statusSortKey(right.status); + if (diff !== 0) return sortDir === "asc" ? diff : -diff; + } else if (videoDiscoverSort === "likes") { + const diff = compareNullableNumber(left.variant.likes ?? null, right.variant.likes ?? null, sortDir); + if (diff !== 0) return diff; + } else if (videoDiscoverSort === "downloads") { + const diff = compareNullableNumber(left.variant.downloads ?? null, right.variant.downloads ?? null, sortDir); + if (diff !== 0) return diff; + } + const dateDiff = releaseSortKey(right.variant).localeCompare(releaseSortKey(left.variant)); + if (dateDiff !== 0) return sortDir === "desc" ? dateDiff : -dateDiff; + return left.variant.name.localeCompare(right.variant.name); + }), + [ + activeVideoDownloads, + combinedVideoDiscoverResults, + installingLongLive, + longLiveReady, + sortDir, + statusFilter, + videoDiscoverSort, + ], + ); + const hasActiveFilters = videoDiscoverHasActiveFilters || statusFilter !== "all"; + + function applySort(nextSort: DiscoverSort) { + if (videoDiscoverSort === nextSort) { + setSortDir(sortDir === "asc" ? "desc" : "asc"); + } else { + onVideoDiscoverSortChange(nextSort); + setSortDir(defaultSortDir(nextSort)); + } + } + return (
@@ -143,16 +310,42 @@ export function VideoDiscoverTab({ +
@@ -162,8 +355,11 @@ export function VideoDiscoverTab({ onClick={() => { onVideoDiscoverSearchInputChange(""); onVideoDiscoverTaskFilterChange("all"); + setStatusFilter("all"); + onVideoDiscoverSortChange("release"); + setSortDir("desc"); }} - disabled={!videoDiscoverHasActiveFilters} + disabled={!hasActiveFilters} > Clear Filters @@ -172,12 +368,7 @@ export function VideoDiscoverTab({
- {combinedVideoDiscoverResults.length} model{combinedVideoDiscoverResults.length !== 1 ? "s" : ""} ·{" "} - {videoDiscoverSort === "likes" - ? "most liked first" - : videoDiscoverSort === "downloads" - ? "most downloads first" - : "newest released first"} + {filteredResults.length} model{filteredResults.length !== 1 ? "s" : ""} · {videoDiscoverSortLabel(videoDiscoverSort)} {videoDiscoverSearchQuery ? ( Search: {videoDiscoverSearchInput.trim()} @@ -185,152 +376,142 @@ export function VideoDiscoverTab({ {videoDiscoverTaskFilter !== "all" ? ( Task: {videoDiscoverTaskFilter} ) : null} + {statusFilter !== "all" ? Status: {statusFilter} : null}
- {combinedVideoDiscoverResults.length === 0 ? ( + {filteredResults.length === 0 ? (

Try broadening the filters or search terms.

) : ( -
- {combinedVideoDiscoverResults.map((variant) => { - const isLongLive = isLongLiveRepo(variant.repo); - const downloadState = activeVideoDownloads[variant.repo]; - const isDownloading = downloadState?.state === "downloading"; - const isPaused = downloadState?.state === "cancelled"; - const isDownloadComplete = downloadState?.state === "completed"; - // LongLive never goes through the HF download pipeline — stale - // failure states from a prior mis-routed Download click would - // otherwise keep rendering "Download Failed" even after we - // offer the correct install CTA. - const isDownloadFailed = - !isLongLive && downloadState?.state === "failed"; - const isComplete = isLongLive - ? longLiveReady - : variant.availableLocally || isDownloadComplete; - const isPartial = !isLongLive && !isComplete && variant.hasLocalData; - const canDeleteLocalData = isLongLive - ? false - : Boolean( - isComplete || isDownloadComplete || isPaused || isDownloadFailed || isPartial, - ); - return ( -
-
-
-

{variant.name}

-

{variant.familyName ?? variant.provider}

-
- {isComplete ? ( - Installed - ) : isLongLive ? ( - installingLongLive ? ( - Installing… - ) : ( - Not installed - ) - ) : isDownloading ? ( - - {downloadProgressLabel(downloadState)} +
+
+ + + + + + + + +
+
+ {filteredResults.map(({ variant, status, memoryEstimate }) => { + const isLongLive = isLongLiveRepo(variant.repo); + const downloadState = videoDownloadStatusForVariant(activeVideoDownloads, variant); + const isComplete = status === "installed"; + const isDownloading = status === "downloading"; + const isPaused = status === "paused"; + const isDownloadFailed = status === "failed"; + const isPartial = status === "incomplete"; + const isDownloadComplete = downloadState?.state === "completed"; + const canDeleteLocalData = isLongLive + ? false + : Boolean(isComplete || isDownloadComplete || isPaused || isDownloadFailed || isPartial); + const localStatusReason = !isComplete && !isDownloading ? variant.localStatusReason : null; + const secondarySize = videoSecondarySizeLabel(variant); + const releaseLabel = compactReleaseLabel(formatReleaseLabel(variant.releaseLabel, variant.releaseDate ?? variant.createdAt)); + const primarySizeLabel = videoPrimarySizeLabel(variant); + const sizeTitle = [primarySizeLabel, secondarySize].filter(Boolean).join(" / "); + const memory = memoryParts(memoryEstimate?.label); + const deleteRepo = videoDeleteRepoForVariant(variant, downloadState); + const deleteLabel = isDownloading + ? "Cancel download" + : videoDeleteLabelForRepo(variant, deleteRepo, "Delete model"); + return ( +
+
+
+ {variant.name} + {variant.note} +
+ {variant.styleTags.slice(0, 4).map((tag) => ( + {tag} + ))} +
+
+ {variant.provider} +
+ {variant.taskSupport.map((task) => ( + {task} + ))} +
+ + {compactModelSizeLabel(primarySizeLabel)} - ) : isPaused ? ( - {downloadProgressLabel(downloadState)} - ) : isDownloadFailed ? ( - Download Failed - ) : isPartial ? ( - Incomplete - ) : null} -
-
- {videoPrimarySizeLabel(variant)} - {videoSecondarySizeLabel(variant) ? {videoSecondarySizeLabel(variant)} : null} - {variant.recommendedResolution} - {number(variant.defaultDurationSeconds)}s clip - {formatReleaseLabel(variant.releaseLabel, variant.releaseDate) ? ( - {formatReleaseLabel(variant.releaseLabel, variant.releaseDate)} - ) : null} - {variant.downloadsLabel ? {variant.downloadsLabel} : null} - {variant.likesLabel ? {variant.likesLabel} : null} - {variant.styleTags.slice(0, 3).map((tag) => ( - {tag} - ))} -
-

{variant.note}

- {isLongLive && !isComplete ? ( -

- LongLive installs into an isolated venv at{" "} - ~/.chaosengine/longlive. CUDA only, 5–15 min - depending on network. -

- ) : null} - {isDownloadFailed && downloadState?.error ? ( -

{downloadState.error}

- ) : null} -
- {isLongLive ? ( - isComplete ? ( - - ) : ( - <> - - - - ) - ) : isComplete ? ( - - ) : isDownloading ? ( - - ) : isPaused ? ( - - ) : ( - - )} - {!isLongLive && (isDownloading || canDeleteLocalData) ? ( - - ) : null} - {variant.localPath ? ( - + + {memory.primary} + {memory.secondary ? {memory.secondary} : null} + + + {releaseLabel ?? "Unknown"} + {variant.downloadsLabel ? {variant.downloadsLabel} : null} + {variant.likesLabel ? {variant.likesLabel} : null} + + {statusBadge(status, downloadState, isLongLive && installingLongLive && !longLiveReady)} +
+ {isLongLive ? ( + isComplete ? ( + onOpenVideoStudio(variant.id)} /> + ) : ( + <> + void onInstallLongLive()} disabled={installingLongLive} /> + + + ) + ) : isComplete ? ( + onOpenVideoStudio(variant.id)} /> + ) : isDownloading ? ( + <> + onCancelVideoDownload(downloadState?.repo ?? variant.repo)} /> + onDeleteVideoDownload(deleteRepo)} /> + + ) : isPaused ? ( + <> + onVideoDownload(variant.repo, variant.id)} /> + onDeleteVideoDownload(deleteRepo)} /> + + ) : ( + onVideoDownload(variant.repo, variant.id)} /> + )} + {!isLongLive && !isDownloading && !isPaused && canDeleteLocalData ? ( + onDeleteVideoDownload(deleteRepo)} /> + ) : null} + {variant.localPath ? ( + onRevealPath(variant.localPath as string)} /> + ) : null} + onOpenExternalUrl(variant.link)} /> +
+
+ {isLongLive && !isComplete ? ( +
+

+ LongLive installs into an isolated venv at ~/.chaosengine/longlive. + CUDA only, 5-15 min depending on network. +

+
+ ) : isDownloadFailed && downloadState?.error ? ( +
+

{downloadState.error}

+
+ ) : localStatusReason ? ( +
+

{localStatusReason}

+
) : null} -
-
- ); - })} + ); + })} +
)}
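Between the two media tabs the header-click behaviour is identical; a self-contained sketch of the toggle logic both applySort implementations share:

type SortDir = "asc" | "desc";
let sort = "release";
let dir: SortDir = "desc";

function applySortSketch(next: string): void {
  if (sort === next) {
    dir = dir === "asc" ? "desc" : "asc"; // same column: flip direction
  } else {
    sort = next;                          // new column: reset to its default
    dir = ["name", "provider", "tasks"].includes(next) ? "asc" : "desc";
  }
}

applySortSketch("name"); // name ascending (text columns default to A-Z)
applySortSketch("name"); // name descending
applySortSketch("size"); // size descending (numeric columns default to largest first)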
diff --git a/src/features/video/VideoModelsTab.tsx b/src/features/video/VideoModelsTab.tsx index 83760b3..c6ab00f 100644 --- a/src/features/video/VideoModelsTab.tsx +++ b/src/features/video/VideoModelsTab.tsx @@ -1,4 +1,6 @@ +import { useMemo, useState } from "react"; import { Panel } from "../../components/Panel"; +import { IconActionButton, StatusIcon } from "../../components/ModelActionIcons"; import type { DownloadStatus } from "../../api"; import type { TabId, @@ -6,7 +8,22 @@ import type { VideoModelVariant, VideoRuntimeStatus, } from "../../types"; -import { downloadProgressLabel, formatReleaseLabel, number, videoPrimarySizeLabel } from "../../utils"; +import { + compactModelSizeLabel, + compactReleaseLabel, + downloadProgressLabel, + formatReleaseLabel, + videoDiscoverMemoryEstimate, + videoDeleteLabelForRepo, + videoDeleteRepoForVariant, + videoDownloadStatusForVariant, + videoPrimarySizeLabel, + videoSecondarySizeLabel, +} from "../../utils"; + +type InstalledVideoSort = "name" | "provider" | "tasks" | "size" | "ram" | "date" | "status"; +type SortDir = "asc" | "desc"; +type InstalledVideoStatusFilter = "all" | "loaded" | "installed" | "incomplete" | "downloading" | "paused" | "failed"; export interface VideoModelsTabProps { installedVideoVariants: VideoModelVariant[]; @@ -19,7 +36,7 @@ export interface VideoModelsTabProps { fileRevealLabel: string; onActiveTabChange: (tab: TabId) => void; onOpenVideoStudio: (modelId?: string) => void; - onVideoDownload: (repo: string) => void; + onVideoDownload: (repo: string, modelId?: string) => void; onCancelVideoDownload: (repo: string) => void; onDeleteVideoDownload: (repo: string) => void; onPreloadVideoModel: (variant: VideoModelVariant) => void; @@ -28,6 +45,91 @@ export interface VideoModelsTabProps { onRevealPath: (path: string) => void; } +function releaseSortKey(variant: VideoModelVariant): string { + return variant.releaseDate ?? variant.createdAt ?? variant.lastModified ?? ""; +} + +function sizeSortKey(variant: VideoModelVariant): number | null { + const candidates = [variant.onDiskGb, variant.coreWeightsGb, variant.repoSizeGb, variant.sizeGb]; + for (const value of candidates) { + if (typeof value === "number" && Number.isFinite(value) && value > 0) return value; + } + return null; +} + +function compareNullableNumberDesc(left: number | null, right: number | null): number { + const leftKnown = typeof left === "number" && Number.isFinite(left); + const rightKnown = typeof right === "number" && Number.isFinite(right); + if (leftKnown && rightKnown) return (right as number) - (left as number); + if (leftKnown) return -1; + if (rightKnown) return 1; + return 0; +} + +function compareNullableNumber(left: number | null, right: number | null, dir: SortDir): number { + const desc = compareNullableNumberDesc(left, right); + return dir === "desc" ? desc : -desc; +} + +function statusSortKey(status: InstalledVideoStatusFilter): number { + if (status === "loaded") return 0; + if (status === "installed") return 1; + if (status === "downloading") return 2; + if (status === "paused") return 3; + if (status === "failed") return 4; + if (status === "incomplete") return 5; + return 6; +} + +function defaultSortDir(sort: InstalledVideoSort): SortDir { + return sort === "name" || sort === "provider" || sort === "tasks" ? 
"asc" : "desc"; +} + +function videoStatus( + variant: VideoModelVariant, + downloadState: DownloadStatus | undefined, + loadedVideoVariant: VideoModelVariant | null, +): InstalledVideoStatusFilter { + if (loadedVideoVariant?.id === variant.id) return "loaded"; + if (downloadState?.state === "downloading") return "downloading"; + if (downloadState?.state === "cancelled") return "paused"; + if (downloadState?.state === "failed") return "failed"; + if (variant.availableLocally || downloadState?.state === "completed") return "installed"; + return "incomplete"; +} + +function statusBadge(status: InstalledVideoStatusFilter, downloadState?: DownloadStatus) { + if (status === "loaded") return ; + if (status === "installed") return ; + if (status === "downloading" && downloadState) return ; + if (status === "paused" && downloadState) return ; + if (status === "failed") return ; + return ; +} + +function sortIndicator(activeSort: InstalledVideoSort, sortDir: SortDir, key: InstalledVideoSort): string { + if (activeSort !== key) return ""; + return sortDir === "asc" ? " \u25B2" : " \u25BC"; +} + +function sortLabel(sort: InstalledVideoSort, sortDir: SortDir): string { + const direction = sortDir === "asc" ? "ascending" : "descending"; + if (sort === "provider") return `provider ${direction}`; + if (sort === "tasks") return `tasks ${direction}`; + if (sort === "size") return sortDir === "desc" ? "largest size first" : "smallest size first"; + if (sort === "ram") return sortDir === "desc" ? "highest RAM/VRAM first" : "lowest RAM/VRAM first"; + if (sort === "status") return `status ${direction}`; + if (sort === "name") return sortDir === "asc" ? "name A-Z" : "name Z-A"; + return sortDir === "desc" ? "newest released first" : "oldest released first"; +} + +function memoryParts(label: string | null | undefined): { primary: string; secondary: string | null } { + if (!label) return { primary: "pending", secondary: null }; + const [primary, secondary] = label.split(" @ "); + if (!secondary) return { primary, secondary: null }; + return { primary: `${primary} @`, secondary }; +} + export function VideoModelsTab({ installedVideoVariants, videoCatalog, @@ -47,12 +149,86 @@ export function VideoModelsTab({ onOpenExternalUrl, onRevealPath, }: VideoModelsTabProps) { + const [searchInput, setSearchInput] = useState(""); + const [taskFilter, setTaskFilter] = useState<"all" | VideoModelVariant["taskSupport"][number]>("all"); + const [statusFilter, setStatusFilter] = useState("all"); + const [sort, setSort] = useState("date"); + const [sortDir, setSortDir] = useState("desc"); + const normalizedSearch = searchInput.trim().toLowerCase(); + const hasActiveFilters = + normalizedSearch.length > 0 || taskFilter !== "all" || statusFilter !== "all" || sort !== "date" || sortDir !== "desc"; + + function applySort(nextSort: InstalledVideoSort) { + if (sort === nextSort) { + setSortDir(sortDir === "asc" ? 
"desc" : "asc"); + } else { + setSort(nextSort); + setSortDir(defaultSortDir(nextSort)); + } + } + + const rows = useMemo(() => { + return installedVideoVariants + .map((variant) => { + const family = videoCatalog.find((item) => + item.variants.some((candidate) => candidate.id === variant.id), + ); + const downloadState = videoDownloadStatusForVariant(activeVideoDownloads, variant); + const status = videoStatus(variant, downloadState, loadedVideoVariant); + const memoryEstimate = videoDiscoverMemoryEstimate(variant); + return { variant, family, downloadState, status, memoryEstimate }; + }) + .filter(({ variant, family, status }) => { + if (taskFilter !== "all" && !variant.taskSupport.includes(taskFilter)) return false; + if (statusFilter !== "all" && status !== statusFilter) return false; + if (!normalizedSearch) return true; + const haystack = [ + variant.name, + variant.provider, + variant.repo, + variant.runtime, + family?.name ?? "", + variant.recommendedResolution, + variant.styleTags.join(" "), + variant.taskSupport.join(" "), + ].join(" ").toLowerCase(); + return haystack.includes(normalizedSearch); + }) + .sort((left, right) => { + if (sort === "name") { + const diff = left.variant.name.localeCompare(right.variant.name); + return sortDir === "asc" ? diff : -diff; + } + if (sort === "provider") { + const diff = left.variant.provider.localeCompare(right.variant.provider); + if (diff !== 0) return sortDir === "asc" ? diff : -diff; + } + if (sort === "tasks") { + const diff = left.variant.taskSupport.join(" ").localeCompare(right.variant.taskSupport.join(" ")); + if (diff !== 0) return sortDir === "asc" ? diff : -diff; + } + if (sort === "size") { + const diff = compareNullableNumber(sizeSortKey(left.variant), sizeSortKey(right.variant), sortDir); + if (diff !== 0) return diff; + } else if (sort === "ram") { + const diff = compareNullableNumber(left.memoryEstimate?.estimatedPeakGb ?? null, right.memoryEstimate?.estimatedPeakGb ?? null, sortDir); + if (diff !== 0) return diff; + } else if (sort === "status") { + const diff = statusSortKey(left.status) - statusSortKey(right.status); + if (diff !== 0) return sortDir === "asc" ? diff : -diff; + } + const dateDiff = releaseSortKey(right.variant).localeCompare(releaseSortKey(left.variant)); + if (dateDiff !== 0) return sortDir === "desc" ? dateDiff : -dateDiff; + return left.variant.name.localeCompare(right.variant.name); + }); + }, [activeVideoDownloads, installedVideoVariants, loadedVideoVariant, normalizedSearch, sort, sortDir, statusFilter, taskFilter, videoCatalog]); + return (
<Panel title={installedVideoVariants.length > 0 - ? `${installedVideoVariants.length} model${installedVideoVariants.length !== 1 ? "s" : ""} with local data` + ? `${rows.length} of ${installedVideoVariants.length} model${installedVideoVariants.length !== 1 ? "s" : ""} with local data` : "No video models detected locally yet"} className="span-2" actions={

Download a video model from Video Discover to get started.

) : ( -
- {installedVideoVariants.map((variant) => { - const family = videoCatalog.find((item) => - item.variants.some((candidate) => candidate.id === variant.id), - ); - const isComplete = variant.availableLocally; - const isPartial = !isComplete && variant.hasLocalData; - const downloadState = activeVideoDownloads[variant.repo]; - const isDownloading = downloadState?.state === "downloading"; - const isPaused = downloadState?.state === "cancelled"; - const isDownloadComplete = downloadState?.state === "completed"; - const isDownloadFailed = downloadState?.state === "failed"; - const canDeleteLocalData = Boolean( - isComplete || isDownloadComplete || isPaused || isDownloadFailed || isPartial, - ); - const isLoadedInMemory = loadedVideoVariant?.id === variant.id; - const canPreload = isComplete && videoRuntimeStatus.realGenerationAvailable && !isLoadedInMemory; - return ( -
-
-
-

{variant.name}

-

{family?.name ?? variant.provider}

-
- {isLoadedInMemory ? ( - In Memory - ) : isComplete || isDownloadComplete ? ( - Installed - ) : isDownloading ? ( - {downloadProgressLabel(downloadState)} - ) : isPaused ? ( - {downloadProgressLabel(downloadState)} - ) : isDownloadFailed ? ( - Download Failed - ) : isPartial ? ( - Incomplete - ) : null} -
-
- {videoPrimarySizeLabel(variant)} - {variant.recommendedResolution} - {number(variant.defaultDurationSeconds)}s clip - {formatReleaseLabel(variant.releaseLabel, variant.releaseDate) ? ( - {formatReleaseLabel(variant.releaseLabel, variant.releaseDate)} - ) : null} - {variant.styleTags.slice(0, 3).map((tag) => ( - {tag} - ))} -
- {isDownloadFailed && downloadState?.error ? ( -

{downloadState.error}

- ) : null} -
- {isComplete || isDownloadComplete ? ( - - ) : isDownloading ? ( - - ) : isPaused ? ( - - ) : ( - - )} - {canPreload ? ( - - ) : null} - {isLoadedInMemory ? ( - - ) : null} - {isDownloading || canDeleteLocalData ? ( - - ) : null} - {variant.localPath ? ( - - ) : null} - -
-
- ); - })} -
+ <> +
+ + + + +
+ +
+
+
+ {rows.length} model{rows.length !== 1 ? "s" : ""} · {sortLabel(sort, sortDir)} + {normalizedSearch ? Search: {searchInput.trim()} : null} + {taskFilter !== "all" ? Task: {taskFilter} : null} + {statusFilter !== "all" ? Status: {statusFilter} : null} +
+ {rows.length === 0 ? ( +
+

No installed video models match the current filters.

+
+ ) : ( +
+
+ + + + + + + + +
+
+ {rows.map(({ variant, family, downloadState, status, memoryEstimate }) => { + const isLoadedInMemory = status === "loaded"; + const isComplete = status === "loaded" || status === "installed"; + const isDownloading = status === "downloading"; + const isPaused = status === "paused"; + const isDownloadFailed = status === "failed"; + const isPartial = status === "incomplete"; + const canDeleteLocalData = Boolean(isComplete || isPaused || isDownloadFailed || isPartial); + const localStatusReason = !isComplete && !isDownloading ? variant.localStatusReason : null; + const secondarySize = videoSecondarySizeLabel(variant); + const releaseLabel = compactReleaseLabel(formatReleaseLabel(variant.releaseLabel, variant.releaseDate ?? variant.createdAt)); + const primarySizeLabel = videoPrimarySizeLabel(variant); + const sizeTitle = [primarySizeLabel, secondarySize].filter(Boolean).join(" / "); + const memory = memoryParts(memoryEstimate?.label); + const deleteRepo = videoDeleteRepoForVariant(variant, downloadState); + const deleteLabel = isDownloading + ? "Cancel download" + : videoDeleteLabelForRepo(variant, deleteRepo, "Delete model"); + return ( +
+
+
+ {variant.name} + {family?.name ?? variant.provider} +
+ {variant.styleTags.slice(0, 4).map((tag) => ( + {tag} + ))} +
+
+ {variant.provider} +
+ {variant.taskSupport.map((task) => ( + {task} + ))} +
+ + {compactModelSizeLabel(primarySizeLabel)} + + + {memory.primary} + {memory.secondary ? {memory.secondary} : null} + + {releaseLabel ?? "Unknown"} + {statusBadge(status, downloadState)} +
+ {isComplete ? ( + onOpenVideoStudio(variant.id)} /> + ) : isDownloading ? ( + onCancelVideoDownload(downloadState?.repo ?? variant.repo)} /> + ) : ( + onVideoDownload(variant.repo, variant.id)} /> + )} + {isDownloading || canDeleteLocalData ? ( + onDeleteVideoDownload(deleteRepo)} /> + ) : null} + {variant.localPath ? ( + onRevealPath(variant.localPath as string)} /> + ) : null} + onOpenExternalUrl(variant.link)} /> +
+
+ {isDownloadFailed && downloadState?.error ? ( +
+

{downloadState.error}

+
+ ) : localStatusReason ? ( +
+

{localStatusReason}

+
+ ) : null} +
+ ); + })} +
+
+ )} + )}
diff --git a/src/features/video/VideoStudioTab.tsx b/src/features/video/VideoStudioTab.tsx index 26448ed..9eb6cbf 100644 --- a/src/features/video/VideoStudioTab.tsx +++ b/src/features/video/VideoStudioTab.tsx @@ -15,6 +15,7 @@ import { defaultVideoVariantForFamily, downloadProgressLabel, number, + videoDownloadStatusForVariant, videoPrimarySizeLabel, videoSecondarySizeLabel, } from "../../utils"; @@ -68,7 +69,7 @@ export interface VideoStudioTabProps { onActiveTabChange: (tab: TabId) => void; onPreloadVideoModel: (variant: VideoModelVariant) => void; onUnloadVideoModel: (variant?: VideoModelVariant) => void; - onVideoDownload: (repo: string) => void; + onVideoDownload: (repo: string, modelId?: string) => void; onGenerateVideo: () => void; onOpenExternalUrl: (url: string) => void; onRestartServer: () => void; @@ -285,6 +286,7 @@ export function VideoStudioTab({ const mp4EncoderMissing = missingDependencies.some( (dep) => dep === "imageio" || dep === "imageio-ffmpeg", ); + const gpuBundleRestartRequired = gpuBundleJob?.phase === "done" && gpuBundleJob.requiresRestart; // Tokenizer / text-encoder packages individual pipelines need lazily — // tiktoken for LTX-Video, sentencepiece for Wan / HunyuanVideo / CogVideoX // / Mochi, plus the protobuf + ftfy support libs. We list them out as a @@ -343,7 +345,7 @@ export function VideoStudioTab({ variants: family.variants.filter((variant) => { if (variant.availableLocally) return true; if (variant.hasLocalData) return true; - const downloadState = activeVideoDownloads[variant.repo]; + const downloadState = videoDownloadStatusForVariant(activeVideoDownloads, variant); return downloadState?.state === "downloading" || downloadState?.state === "completed"; }), })) @@ -426,7 +428,7 @@ export function VideoStudioTab({ && !(mlxVideoStatus.missingDependencies ?? []).includes("mlx-video"); const downloadState = useMemo( - () => (selectedVideoVariant ? activeVideoDownloads[selectedVideoVariant.repo] : undefined), + () => (selectedVideoVariant ? videoDownloadStatusForVariant(activeVideoDownloads, selectedVideoVariant) : undefined), [activeVideoDownloads, selectedVideoVariant], ); const isDownloading = downloadState?.state === "downloading"; @@ -457,6 +459,8 @@ export function VideoStudioTab({ ? "Choose a video model first." : !isDownloaded ? `${selectedVideoVariant.name} is not installed locally yet.` + : gpuBundleRestartRequired + ? "Restart the backend to activate the newly installed GPU runtime before generating." : !selectedVideoRuntimeStatus.realGenerationAvailable ? (selectedVideoRuntimeStatus.message || "Video runtime is not ready.") : !hasPrompt @@ -490,6 +494,9 @@ export function VideoStudioTab({ deviceMemoryGb: selectedVideoRuntimeStatus.deviceMemoryGb, baseModelFootprintGb: selectedVideoVariant?.sizeGb, runtimeFootprintGb: selectedVideoVariant?.runtimeFootprintGb, + runtimeFootprintMpsGb: selectedVideoVariant?.runtimeFootprintMpsGb, + runtimeFootprintCudaGb: selectedVideoVariant?.runtimeFootprintCudaGb, + runtimeFootprintCpuGb: selectedVideoVariant?.runtimeFootprintCpuGb, }), [ videoWidth, @@ -499,6 +506,9 @@ export function VideoStudioTab({ selectedVideoRuntimeStatus.deviceMemoryGb, selectedVideoVariant?.sizeGb, selectedVideoVariant?.runtimeFootprintGb, + selectedVideoVariant?.runtimeFootprintMpsGb, + selectedVideoVariant?.runtimeFootprintCudaGb, + selectedVideoVariant?.runtimeFootprintCpuGb, ], ); @@ -599,6 +609,9 @@ export function VideoStudioTab({ {videoRuntimeStatus.realGenerationAvailable ? 
"Real engine ready" : "Fallback active"} + {gpuBundleRestartRequired ? ( + Restart required + ) : null} Engine: {videoRuntimeStatus.activeEngine} {/* Prefer the actual-loaded device; fall back to the predicted * expectedDevice computed via nvidia-smi + find_spec (no torch @@ -733,36 +746,31 @@ export function VideoStudioTab({
) : null} - {!videoRuntimeStatus.realGenerationAvailable ? ( + {gpuBundleRestartRequired ? ( + <> +
+

+ GPU runtime installed to{" "} + {gpuBundleJob.targetDir ?? "extras"}. The running backend + still has its old import cache — click Restart Backend to activate the + new runtime; video generation will then use it.

+
+ +
+
+ + + ) : !videoRuntimeStatus.realGenerationAvailable ? ( <>
- {/* Same post-install-awaiting-restart branch Image Studio - * uses. After a successful GPU bundle install, the - * running backend still can't see the new torch in - * extras (PYTHONPATH is snapshotted at spawn). Nudge - * the user toward Restart Backend instead of asking - * them to install again. */} - {gpuBundleJob?.phase === "done" && gpuBundleJob.requiresRestart ? ( - <> -

- GPU runtime installed to{" "} - {gpuBundleJob.targetDir ?? "extras"}. The running backend - still has its old import cache — click Restart Backend to activate the - new runtime, then video generation will use your GPU. -

-
- -
- - ) : ( - <>

Video generation needs the GPU runtime bundle (torch + diffusers + tokenizers, ~2.5 GB). Install it once — it writes to a persistent user-local directory so @@ -781,8 +789,6 @@ export function VideoStudioTab({ {busyAction === "Restarting server..." ? "Restarting..." : "Restart Backend"}

- - )}
@@ -800,7 +806,7 @@ export function VideoStudioTab({ > {studioFamilies.flatMap((family) => family.variants.map((variant) => { - const downloadState = activeVideoDownloads[variant.repo]; + const downloadState = videoDownloadStatusForVariant(activeVideoDownloads, variant); const isDownloadingVariant = downloadState?.state === "downloading"; const suffix = variant.availableLocally ? " (installed)" @@ -846,7 +852,9 @@ export function VideoStudioTab({ ) : isDownloading ? ( {downloadProgressLabel(downloadState)} ) : ( - Not downloaded + + {selectedVideoVariant.hasLocalData ? "Incomplete" : "Not downloaded"} + )} {selectedVideoLoaded ? In Memory : null} {videoRuntimeLoadedDifferentModel && loadedVideoVariant ? ( @@ -855,6 +863,12 @@ export function VideoStudioTab({
) : null} + {selectedVideoVariant?.localStatusReason && !isDownloaded && !isDownloading ? ( +

+ {selectedVideoVariant.localStatusReason} +

+ ) : null} +
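A condensed sketch of the generate-button preflight after this change; the first failing check wins, and the new restart gate sits between the local-install check and the runtime-readiness check (the prompt message is hypothetical, since the diff truncates there):

function generateDisabledReason(s: {
  modelName: string | null;
  isDownloaded: boolean;
  gpuBundleRestartRequired: boolean;
  runtimeReady: boolean;
  runtimeMessage: string | null;
  hasPrompt: boolean;
}): string | null {
  if (!s.modelName) return "Choose a video model first.";
  if (!s.isDownloaded) return `${s.modelName} is not installed locally yet.`;
  if (s.gpuBundleRestartRequired) {
    return "Restart the backend to activate the newly installed GPU runtime before generating.";
  }
  if (!s.runtimeReady) return s.runtimeMessage || "Video runtime is not ready.";
  if (!s.hasPrompt) return "Enter a prompt first."; // wording hypothetical
  return null;
}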