Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion backend_service/helpers/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
_parse_iso_datetime,
)
from backend_service.helpers.discovery import _candidate_model_dirs, _path_size_bytes
from backend_service.helpers.platform_filter import (
filter_mlx_only_families,
is_apple_silicon,
)
from backend_service.image_runtime import validate_local_diffusers_snapshot


Expand Down Expand Up @@ -196,7 +200,7 @@ def _image_model_payloads(library: list[dict[str, Any]]) -> list[dict[str, Any]]
"variants": variants,
}
)
return families
return filter_mlx_only_families(families, on_apple_silicon=is_apple_silicon())


def _find_image_variant(model_id: str) -> dict[str, Any] | None:
Expand Down
84 changes: 84 additions & 0 deletions backend_service/helpers/platform_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Platform-aware filtering for the image + video model catalogs.

Some catalog variants only run on Apple Silicon: ``mflux`` (image) routes
through ``mflux``/``mlx-lm`` and ``prince-canuma/LTX-2-*`` (video) routes
through ``mlx-video``. Both of those Python packages depend on ``mlx``,
which has no Linux or Windows wheels. Surfacing those variants in the
Image Studio / Video Studio dropdowns on the wrong OS lets users pick
something that cannot run, so this module strips them server-side
before the payload reaches the frontend.

The detection is conservative: a variant is treated as MLX-only iff it
declares so explicitly via ``mlxOnly`` or it carries one of the runtime
labels we know is Apple-only. New runtime labels need to be added here
when they ship — falsely keeping an entry visible is a regression we'd
catch at smoke test, falsely hiding one isn't.
"""

from __future__ import annotations

import platform
from typing import Any


_MLX_ONLY_RUNTIME_MARKERS: tuple[str, ...] = (
"mflux (MLX native)",
"mlx-video (MLX native)",
)

_MLX_ONLY_ENGINES: frozenset[str] = frozenset({"mflux", "mlx-video"})


def is_apple_silicon(system: str | None = None, machine: str | None = None) -> bool:
"""True iff the host is Darwin running on arm64.

Both arguments are exposed for tests so the platform check can be
pinned without monkeypatching ``platform`` itself. They default to
the live host values.
"""
sys_name = system if system is not None else platform.system()
arch = machine if machine is not None else platform.machine()
return sys_name == "Darwin" and arch == "arm64"


def is_mlx_only_variant(variant: dict[str, Any]) -> bool:
"""True iff the variant cannot run outside Apple Silicon."""
if variant.get("mlxOnly") is True:
return True
engine = str(variant.get("engine") or "").strip().lower()
if engine in _MLX_ONLY_ENGINES:
return True
runtime = str(variant.get("runtime") or "")
return any(marker in runtime for marker in _MLX_ONLY_RUNTIME_MARKERS)


def filter_mlx_only_families(
families: list[dict[str, Any]],
*,
on_apple_silicon: bool,
) -> list[dict[str, Any]]:
"""Strip MLX-only variants from a catalog payload on non-Apple hosts.

On Apple Silicon every variant is preserved untouched. On every other
OS the MLX-only variants are dropped from each family's ``variants``
list, and any family whose entire variant set is MLX-only is dropped
from the result so the UI doesn't render an empty card.

Returns a new list — the input is not mutated.
"""
if on_apple_silicon:
return families

filtered: list[dict[str, Any]] = []
for family in families:
variants = [
variant
for variant in family.get("variants", [])
if not is_mlx_only_variant(variant)
]
if not variants:
continue
new_family = dict(family)
new_family["variants"] = variants
filtered.append(new_family)
return filtered
6 changes: 5 additions & 1 deletion backend_service/helpers/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from backend_service.helpers.formatting import _bytes_to_gb
from backend_service.helpers.huggingface import _format_release_label, _hf_repo_snapshot_dir
from backend_service.helpers.images import _image_repo_live_metadata, _snapshot_on_disk_bytes
from backend_service.helpers.platform_filter import (
filter_mlx_only_families,
is_apple_silicon,
)
from backend_service.image_runtime import validate_local_diffusers_snapshot


Expand Down Expand Up @@ -113,7 +117,7 @@ def _video_model_payloads(library: list[dict[str, Any]]) -> list[dict[str, Any]]
payload = dict(family)
payload["variants"] = variants
families.append(payload)
return families
return filter_mlx_only_families(families, on_apple_silicon=is_apple_silicon())


def _find_video_variant(model_id: str) -> dict[str, Any] | None:
Expand Down
144 changes: 144 additions & 0 deletions tests/test_platform_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Tests for the MLX-only catalog filter.

Validates that ``filter_mlx_only_families`` strips Apple-only variants on
non-Apple hosts and leaves them visible on Apple Silicon. The detector
covers explicit ``mlxOnly`` flags, ``engine`` markers, and the runtime
strings used by the live catalog.
"""

from __future__ import annotations

import unittest

from backend_service.helpers.platform_filter import (
filter_mlx_only_families,
is_apple_silicon,
is_mlx_only_variant,
)


def _flux_dev_gguf() -> dict[str, object]:
return {
"id": "black-forest-labs/FLUX.1-dev-gguf-q8",
"name": "FLUX.1 Dev · GGUF Q8_0",
"engine": None,
"runtime": "Stub diffusion pipeline",
"styleTags": ["general", "detailed", "gguf"],
}


def _flux_dev_mflux() -> dict[str, object]:
return {
"id": "black-forest-labs/FLUX.1-dev-mflux",
"name": "FLUX.1 Dev · mflux (MLX)",
"engine": "mflux",
"runtime": "mflux (MLX native)",
"styleTags": ["general", "detailed", "apple-silicon"],
}


def _ltx2_distilled_mlx() -> dict[str, object]:
return {
"id": "prince-canuma/LTX-2-distilled",
"name": "LTX-2 · distilled (MLX)",
"runtime": "mlx-video (MLX native)",
"styleTags": ["general", "fast", "motion", "mlx"],
}


def _wan_diffusers() -> dict[str, object]:
return {
"id": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
"name": "Wan 2.1 T2V 1.3B",
"runtime": "diffusers (MPS / CUDA)",
"styleTags": ["general", "motion"],
}


class IsAppleSiliconTests(unittest.TestCase):
def test_darwin_arm64_is_apple_silicon(self) -> None:
self.assertTrue(is_apple_silicon(system="Darwin", machine="arm64"))

def test_darwin_x86_64_is_not_apple_silicon(self) -> None:
self.assertFalse(is_apple_silicon(system="Darwin", machine="x86_64"))

def test_windows_is_not_apple_silicon(self) -> None:
self.assertFalse(is_apple_silicon(system="Windows", machine="AMD64"))

def test_linux_is_not_apple_silicon(self) -> None:
self.assertFalse(is_apple_silicon(system="Linux", machine="x86_64"))


class IsMlxOnlyVariantTests(unittest.TestCase):
def test_mflux_engine_marker(self) -> None:
self.assertTrue(is_mlx_only_variant(_flux_dev_mflux()))

def test_mlx_video_runtime_marker(self) -> None:
self.assertTrue(is_mlx_only_variant(_ltx2_distilled_mlx()))

def test_explicit_mlx_only_flag(self) -> None:
variant = {"id": "x", "name": "x", "mlxOnly": True}
self.assertTrue(is_mlx_only_variant(variant))

def test_diffusers_runtime_is_not_mlx_only(self) -> None:
self.assertFalse(is_mlx_only_variant(_wan_diffusers()))

def test_gguf_variant_is_not_mlx_only(self) -> None:
self.assertFalse(is_mlx_only_variant(_flux_dev_gguf()))

def test_engine_field_case_insensitive(self) -> None:
variant = {"id": "x", "engine": "MFlux"}
self.assertTrue(is_mlx_only_variant(variant))


class FilterMlxOnlyFamiliesTests(unittest.TestCase):
def setUp(self) -> None:
self.flux_family = {
"id": "flux-dev",
"name": "FLUX.1 Dev",
"variants": [_flux_dev_gguf(), _flux_dev_mflux()],
}
self.ltx_only_family = {
"id": "ltx-2",
"name": "LTX-2 (MLX)",
"variants": [_ltx2_distilled_mlx()],
}
self.wan_family = {
"id": "wan-2-1",
"name": "Wan 2.1",
"variants": [_wan_diffusers()],
}

def test_apple_silicon_passes_everything_through(self) -> None:
families = [self.flux_family, self.ltx_only_family, self.wan_family]
result = filter_mlx_only_families(families, on_apple_silicon=True)
self.assertEqual(len(result), 3)
self.assertEqual([f["id"] for f in result], ["flux-dev", "ltx-2", "wan-2-1"])

def test_non_apple_drops_mlx_variants(self) -> None:
families = [self.flux_family]
result = filter_mlx_only_families(families, on_apple_silicon=False)
self.assertEqual(len(result), 1)
ids = [v["id"] for v in result[0]["variants"]]
self.assertEqual(ids, ["black-forest-labs/FLUX.1-dev-gguf-q8"])

def test_non_apple_drops_mlx_only_families(self) -> None:
"""A family whose only variant is MLX-only disappears entirely."""
families = [self.flux_family, self.ltx_only_family, self.wan_family]
result = filter_mlx_only_families(families, on_apple_silicon=False)
ids = [f["id"] for f in result]
self.assertEqual(ids, ["flux-dev", "wan-2-1"])

def test_does_not_mutate_input(self) -> None:
families = [self.flux_family]
original_variant_count = len(families[0]["variants"])
_ = filter_mlx_only_families(families, on_apple_silicon=False)
self.assertEqual(len(families[0]["variants"]), original_variant_count)

def test_empty_input_returns_empty(self) -> None:
self.assertEqual(filter_mlx_only_families([], on_apple_silicon=True), [])
self.assertEqual(filter_mlx_only_families([], on_apple_silicon=False), [])


if __name__ == "__main__":
unittest.main()
Loading