diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index 3a2e1389c..262be82d5 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -103,6 +103,60 @@ jobs:
       - name: Build check
         run: uv build
 
+  mps-checks:
+    name: MPS Checks
+    runs-on: macos-latest
+    # Only run on PRs targeting main or on pushes directly to main
+    if: >
+      (github.event_name == 'pull_request' && github.base_ref == 'main') ||
+      (github.event_name == 'push' && github.ref == 'refs/heads/main')
+    steps:
+      - uses: actions/checkout@v4
+      - name: Install uv
+        uses: astral-sh/setup-uv@v6
+        with:
+          python-version: "3.11"
+          activate-environment: true
+          enable-cache: true
+      - name: MPS Cache Models
+        uses: actions/cache@v3
+        with:
+          path: |
+            ~/.cache/huggingface/hub/models--roneneldan--TinyStories-1M*
+          key: ${{ runner.os }}-huggingface-models-mps-v1
+      - name: Install dependencies
+        run: |
+          uv lock --check
+          uv sync
+      - name: MPS Availability Check
+        run: |
+          uv run python -c "
+          import torch
+          print(f'PyTorch: {torch.__version__}')
+          print(f'MPS available: {torch.backends.mps.is_available()}')
+          print(f'MPS built: {torch.backends.mps.is_built()}')
+          assert torch.backends.mps.is_available(), 'MPS not available on this runner!'
+          "
+      - name: MPS Unit Tests
+        run: >
+          uv run pytest tests/unit -v
+          --ignore=tests/unit/model_bridge/
+      - name: MPS Integration Tests
+        run: >
+          uv run pytest tests/integration -v
+          --ignore=tests/integration/model_bridge/
+          --ignore=tests/integration/test_prepend_bos.py
+          --ignore=tests/integration/test_generation_compatibility.py
+          --ignore=tests/integration/test_grouped_query_attention.py
+          --ignore=tests/integration/test_match_huggingface.py
+          --ignore=tests/integration/test_fold_layer_integration.py
+          --ignore=tests/integration/test_centralized_weight_processing.py
+          --ignore=tests/integration/test_create_hooked_encoder.py
+      - name: MPS Smoke Tests
+        run: uv run pytest tests/mps -v
+        env:
+          TRANSFORMERLENS_ALLOW_MPS: "1"
+
   format-check:
     name: Format Check
     runs-on: ubuntu-latest
diff --git a/pyproject.toml b/pyproject.toml
index e20e94ce6..da70ba811 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -90,7 +90,10 @@
     "ignore:distutils Version classes are deprecated:DeprecationWarning",
     "ignore:pkg_resources is deprecated as an API:DeprecationWarning",
 ]
-markers=["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
+markers=[
+    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
+    "no_mps: marks tests as incompatible with the MPS device (deselect with '-m \"not no_mps\"')",
+]
 pythonpath=["."]
 testpaths=["tests", "transformer_lens"]  # Only test these directories
diff --git a/tests/conftest.py b/tests/conftest.py
index 4bb009c8e..3229823af 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,9 +14,11 @@ def cleanup_memory():
     """Automatically clean up memory after each test."""
     yield
 
-    # Clear torch cache
+    # Clear torch cache for all accelerators
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    if torch.backends.mps.is_available():
+        torch.mps.empty_cache()
 
     # Force garbage collection for cleanup
     gc.collect()
@@ -28,6 +30,8 @@ def cleanup_class_memory():
     # More aggressive cleanup after test classes
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    if torch.backends.mps.is_available():
+        torch.mps.empty_cache()
     gc.collect()
 
 
@@ -50,6 +54,8 @@ def pytest_sessionfinish(session, exitstatus):
     """Clean up at the end of test session."""
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    if torch.backends.mps.is_available():
+        torch.mps.empty_cache()
     gc.collect()
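For reference, the `no_mps` marker registered in pyproject.toml above is opt-in per test. A minimal usage sketch (the test name here is hypothetical, not part of this diff):

    import pytest

    @pytest.mark.no_mps  # hypothetical test that cannot run on Apple Silicon
    def test_requires_float64():
        # Deselect on MPS runners with: pytest -m "not no_mps"
        ...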
diff --git a/tests/mps/__init__.py b/tests/mps/__init__.py
new file mode 100644
index 000000000..d319c86cd
--- /dev/null
+++ b/tests/mps/__init__.py
@@ -0,0 +1 @@
+# MPS (Apple Silicon) test package
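The smoke tests added below can also be run locally the same way the `mps-checks` job runs them. A sketch, assuming an Apple Silicon Mac with PyTorch 2.0+ (the in-process pytest.main call stands in for the CI's `uv run pytest tests/mps -v`):

    import os
    import sys

    import pytest

    # Mirror the CI job: opt in to MPS device selection, then run the smoke suite.
    os.environ["TRANSFORMERLENS_ALLOW_MPS"] = "1"
    sys.exit(pytest.main(["tests/mps", "-v"]))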
diff --git a/tests/mps/test_mps_basic.py b/tests/mps/test_mps_basic.py
new file mode 100644
index 000000000..efe12aada
--- /dev/null
+++ b/tests/mps/test_mps_basic.py
@@ -0,0 +1,246 @@
+"""Apple Silicon MPS smoke tests for TransformerLens.
+
+Design principles:
+- All tests skip automatically on non-MPS runners (Linux, Windows, CPU-only Macs)
+- Only float32 is used (bfloat16 is not reliably supported on MPS)
+- Only small models are loaded (roneneldan/TinyStories-1M, ~50MB)
+- torch.mps.empty_cache() + gc.collect() between tests to stay within memory budget
+- TRANSFORMERLENS_ALLOW_MPS=1 must be set for get_device() to return "mps"
+
+CI: These tests are run via the `mps-checks` job in .github/workflows/checks.yml,
+which sets TRANSFORMERLENS_ALLOW_MPS=1 and runs on macos-latest.
+"""
+
+import gc
+import os
+import warnings
+
+import pytest
+import torch
+
+# Skip the entire module on non-MPS runners (Linux CI, CPU-only Macs)
+pytestmark = pytest.mark.skipif(
+    not torch.backends.mps.is_available(),
+    reason="MPS not available on this runner — skipping Apple Silicon tests",
+)
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+SMALL_MODEL = "roneneldan/TinyStories-1M"  # ~50MB, safe for 1GB runner budget
+
+
+def _load_tiny_model(device: str = "mps"):
+    """Load TinyStories-1M on the given device with float32 (bfloat16 is not reliably supported on MPS)."""
+    from transformer_lens import HookedTransformer
+
+    return HookedTransformer.from_pretrained(SMALL_MODEL, device=device, dtype=torch.float32)
+
+
+def _cleanup(model=None):
+    """Free GPU memory between tests."""
+    if model is not None:
+        del model
+    torch.mps.empty_cache()
+    gc.collect()
+
+
+# ---------------------------------------------------------------------------
+# 1. Device detection (no model load — instant)
+# ---------------------------------------------------------------------------
+
+
+def test_mps_device_available():
+    """Sanity check: MPS backend is present and built on this runner."""
+    assert torch.backends.mps.is_available(), "MPS not available"
+    assert torch.backends.mps.is_built(), "MPS not built into this PyTorch"
+
+
+def test_mps_get_device_returns_mps_with_env_var():
+    """get_device() auto-selects MPS when TRANSFORMERLENS_ALLOW_MPS=1 is set."""
+    from transformer_lens.utilities.devices import get_device
+
+    original = os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "")
+    try:
+        os.environ["TRANSFORMERLENS_ALLOW_MPS"] = "1"
+        device = get_device()
+        assert isinstance(device, torch.device)
+        assert device.type == "mps", f"Expected 'mps', got '{device.type}'"
+    finally:
+        if original:
+            os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original
+        else:
+            os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
+
+
+def test_mps_get_device_falls_back_to_cpu_without_env_var():
+    """get_device() falls back to CPU when TRANSFORMERLENS_ALLOW_MPS is unset (safety default)."""
+    from transformer_lens.utilities.devices import get_device
+
+    original = os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "")
+    try:
+        os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
+        device = get_device()
+        # On a Mac with no CUDA, should return cpu (safe default)
+        assert isinstance(device, torch.device)
+        assert (
+            device.type == "cpu"
+        ), f"Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not '{device.type}'"
+    finally:
+        if original:
+            os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original
+
+
+def test_mps_warn_if_mps_emits_warning_without_env_var():
+    """warn_if_mps() emits a UserWarning when MPS is used without the env var."""
+    import transformer_lens.utilities.devices as devices_module
+    from transformer_lens.utilities import warn_if_mps
+
+    original = os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "")
+    original_warned = devices_module._mps_warned
+    try:
+        os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
+        devices_module._mps_warned = False  # reset so warning fires
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            warn_if_mps("mps")
+        assert any(
+            "MPS backend" in str(warning.message) for warning in w
+        ), "Expected MPS warning but got: " + str([str(x.message) for x in w])
+    finally:
+        if original:
+            os.environ["TRANSFORMERLENS_ALLOW_MPS"] = original
+        devices_module._mps_warned = original_warned
+
+
+# ---------------------------------------------------------------------------
+# 2. Raw tensor operations on Metal (no model load)
+# ---------------------------------------------------------------------------
+
+
+def test_mps_tensor_basic_operations():
+    """Basic tensor arithmetic runs on the Metal GPU without errors."""
+    x = torch.randn(16, 32, device="mps", dtype=torch.float32)
+    y = torch.randn(16, 32, device="mps", dtype=torch.float32)
+
+    z = x + y
+    assert z.device.type == "mps"
+
+    w = torch.matmul(x, y.T)
+    assert w.device.type == "mps"
+    assert w.shape == (16, 16)
+
+    # Verify result comes back to CPU correctly
+    z_cpu = z.cpu()
+    assert z_cpu.device.type == "cpu"
+
+    _cleanup()
+
+
+def test_mps_softmax_and_layernorm():
+    """Softmax and LayerNorm — core transformer ops — work on MPS."""
+    x = torch.randn(4, 16, 64, device="mps", dtype=torch.float32)
+
+    softmax_out = torch.nn.functional.softmax(x, dim=-1)
+    assert softmax_out.device.type == "mps"
+    assert torch.allclose(softmax_out.sum(dim=-1), torch.ones(4, 16, device="mps"), atol=1e-5)
+
+    ln = torch.nn.LayerNorm(64).to("mps")
+    ln_out = ln(x)
+    assert ln_out.device.type == "mps"
+
+    _cleanup()
+
+
+# ---------------------------------------------------------------------------
+# 3. Model loading and forward pass on Metal
+# ---------------------------------------------------------------------------
+
+
+def test_mps_model_forward_pass():
+    """TinyStories-1M loads and runs a forward pass on the Metal GPU."""
+    model = _load_tiny_model(device="mps")
+
+    tokens = model.to_tokens("Once upon a time")
+    assert tokens.device.type == "mps", f"Tokens should be on MPS, got {tokens.device}"
+
+    logits = model(tokens)
+    assert logits.device.type == "mps", f"Logits should be on MPS, got {logits.device}"
+    assert logits.shape[-1] == model.cfg.d_vocab
+    assert not torch.isnan(logits).any(), "NaN values in logits — possible MPS compute error"
+
+    _cleanup(model)
+
+
+def test_mps_run_with_cache():
+    """run_with_cache() returns cache tensors on the Metal GPU."""
+    model = _load_tiny_model(device="mps")
+    tokens = model.to_tokens("The quick brown fox")
+
+    logits, cache = model.run_with_cache(tokens)
+
+    assert logits.device.type == "mps"
+
+    # Check a representative set of cache keys
+    hook_q = cache["blocks.0.attn.hook_q"]
+    assert hook_q.device.type == "mps", f"Cache tensor not on MPS: {hook_q.device}"
+    assert not torch.isnan(hook_q).any(), "NaN in attention query cache"
+
+    _cleanup(model)
+
+
+def test_mps_activation_hook_fires_on_metal():
+    """run_with_hooks() fires hooks and hook tensors are on the Metal GPU."""
+    model = _load_tiny_model(device="mps")
+    tokens = model.to_tokens("Apple Silicon rocks")
+
+    hook_devices = []
+    hook_shapes = []
+
+    def capture_hook(value, hook):
+        hook_devices.append(value.device.type)
+        hook_shapes.append(value.shape)
+        return value
+
+    model.run_with_hooks(
+        tokens,
+        fwd_hooks=[
+            ("blocks.0.attn.hook_q", capture_hook),
+            ("blocks.0.mlp.hook_post", capture_hook),
+        ],
+    )
+
+    assert len(hook_devices) == 2, f"Expected 2 hooks to fire, got {len(hook_devices)}"
+    for device in hook_devices:
+        assert device == "mps", f"Hook tensor not on MPS: {device}"
+
+    _cleanup(model)
+
+
+def test_mps_float32_inference():
+    """Explicit float32 model loads and infers correctly on MPS."""
+    model = _load_tiny_model(device="mps")
+
+    # Verify all parameters are float32
+    for name, param in model.named_parameters():
+        assert param.dtype == torch.float32, f"Parameter {name} has wrong dtype: {param.dtype}"
+
+    tokens = model.to_tokens("Testing float32 on Metal")
+    logits = model(tokens)
+    assert logits.dtype == torch.float32
+
+    _cleanup(model)
+
+
+def test_mps_loss_computation():
+    """Loss computation (return_type='loss') works on MPS."""
+    model = _load_tiny_model(device="mps")
+
+    loss = model("Once upon a time in a land", return_type="loss")
+    assert isinstance(loss, torch.Tensor)
+    assert loss.device.type == "mps"
+    assert not torch.isnan(loss), f"NaN loss — possible MPS compute error: {loss}"
+    assert loss.item() > 0, "Loss should be positive"
+
+    _cleanup(model)
diff --git a/tests/unit/utilities/test_devices.py b/tests/unit/utilities/test_devices.py
index 6ea3b4095..ef8bfb227 100644
--- a/tests/unit/utilities/test_devices.py
+++ b/tests/unit/utilities/test_devices.py
@@ -43,7 +43,8 @@ def test_get_device_cuda_available():
     with patch("torch.cuda.is_available", return_value=True):
         with patch("torch.backends.mps.is_available", return_value=False):
             device = get_device()
-            assert device == torch.device("cuda")
+            assert isinstance(device, torch.device)
+            assert device.type == "cuda"
 
 
 @patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"})
@@ -54,7 +55,8 @@ def test_get_device_mps_available():
     with patch("torch.backends.mps.is_built", return_value=True):
         with patch("torch.__version__", "2.0.0"):
             device = get_device()
-            assert device == torch.device("mps")
+            assert isinstance(device, torch.device)
+            assert device.type == "mps"
 
 
 def test_get_device_mps_pytorch_1x():
@@ -64,7 +66,8 @@ def test_get_device_mps_pytorch_1x():
     with patch("torch.backends.mps.is_built", return_value=True):
         with patch("torch.__version__", "1.13.0"):
             device = get_device()
-            assert device == torch.device("cpu")
+            assert isinstance(device, torch.device)
+            assert device.type == "cpu"
 
 
 def test_get_device_cpu_fallback():
@@ -72,7 +75,8 @@ def test_get_device_cpu_fallback():
     with patch("torch.cuda.is_available", return_value=False):
         with patch("torch.backends.mps.is_available", return_value=False):
             device = get_device()
-            assert device == torch.device("cpu")
+            assert isinstance(device, torch.device)
+            assert device.type == "cpu"
 
 
 def test_model_with_cfg_protocol():
@@ -178,7 +182,8 @@ def test_get_device_returns_cpu_when_mps_available(mock_built, mock_avail, mock_
         os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
 
         result = get_device()
-        assert result == torch.device("cpu")
+        assert isinstance(result, torch.device)
+        assert result.type == "cpu"
 
 
 @patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"})
@@ -188,7 +193,8 @@ def test_get_device_returns_mps_when_env_var_set(mock_built, mock_avail, mock_cuda):
     """get_device() should return MPS when TRANSFORMERLENS_ALLOW_MPS=1 is set."""
     result = get_device()
-    assert result == torch.device("mps")
+    assert isinstance(result, torch.device)
+    assert result.type == "mps"
 
 
 @patch.dict("os.environ", {}, clear=False)
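The assertion changes in test_devices.py above swap whole-device equality for `.type` checks because `torch.device` equality compares both the device type and the index, so an indexed device does not equal its bare spelling. A quick illustration (pure construction, no GPU required):

    import torch

    # Equality is index-sensitive: "cuda" (index None) != "cuda:0" (index 0).
    assert torch.device("cuda") != torch.device("cuda:0")

    # .type ignores the index, so both spellings pass the rewritten assertions.
    assert torch.device("cuda").type == "cuda"
    assert torch.device("cuda:0").type == "cuda"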
diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py
index c76f9c7b7..95494ad23 100644
--- a/transformer_lens/HookedAudioEncoder.py
+++ b/transformer_lens/HookedAudioEncoder.py
@@ -356,7 +356,7 @@ def from_pretrained(
         checkpoint_index: Optional[int] = None,
         checkpoint_value: Optional[int] = None,
         hf_model: Optional[Any] = None,
-        device: Optional[str] = None,
+        device: Optional[Union[str, torch.device]] = None,
         move_to_device: bool = True,
         dtype: torch.dtype = torch.float32,
         **from_pretrained_kwargs: Any,
diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py
index 4c239f3d8..e6c29ce18 100644
--- a/transformer_lens/HookedEncoder.py
+++ b/transformer_lens/HookedEncoder.py
@@ -377,7 +377,7 @@ def from_pretrained(
         checkpoint_index: Optional[int] = None,
         checkpoint_value: Optional[int] = None,
         hf_model: Optional[Any] = None,
-        device: Optional[str] = None,
+        device: Optional[Union[str, torch.device]] = None,
         tokenizer: Optional[Any] = None,
         move_to_device: bool = True,
         dtype: torch.dtype = torch.float32,
diff --git a/transformer_lens/HookedEncoderDecoder.py b/transformer_lens/HookedEncoderDecoder.py
index e683d2f91..f3e4a9402 100644
--- a/transformer_lens/HookedEncoderDecoder.py
+++ b/transformer_lens/HookedEncoderDecoder.py
@@ -544,7 +544,7 @@ def from_pretrained(
         checkpoint_index: Optional[int] = None,
         checkpoint_value: Optional[int] = None,
         hf_model: Optional[Any] = None,
-        device: Optional[str] = None,
+        device: Optional[Union[str, torch.device]] = None,
         tokenizer: Optional[Any] = None,
         move_to_device: bool = True,
         dtype: Optional[torch.dtype] = torch.float32,
diff --git a/transformer_lens/config/HookedTransformerConfig.py b/transformer_lens/config/HookedTransformerConfig.py
index e8450df1f..5d8062dd6 100644
--- a/transformer_lens/config/HookedTransformerConfig.py
+++ b/transformer_lens/config/HookedTransformerConfig.py
@@ -334,7 +334,7 @@ def __post_init__(self):
             self.n_params += self.n_layers * mlp_params_per_layer
 
         if self.device is None:
-            self.device = str(get_device())
+            self.device = get_device()
         else:
             from transformer_lens.utilities import warn_if_mps
diff --git a/transformer_lens/config/TransformerLensConfig.py b/transformer_lens/config/TransformerLensConfig.py
index fb1f5f045..df7fd352d 100644
--- a/transformer_lens/config/TransformerLensConfig.py
+++ b/transformer_lens/config/TransformerLensConfig.py
@@ -59,7 +59,7 @@ class TransformerLensConfig:
     d_vocab: int = -1
 
     # Device configuration
-    device: Optional[str] = None
+    device: Optional[Union[str, torch.device]] = None
 
     # Attention configuration
     use_attn_result: bool = False
diff --git a/transformer_lens/lit/model.py b/transformer_lens/lit/model.py
index b66e3ba72..42e27a1f6 100644
--- a/transformer_lens/lit/model.py
+++ b/transformer_lens/lit/model.py
@@ -34,7 +34,7 @@
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Union
 
 import torch
 
@@ -86,7 +86,7 @@ class HookedTransformerLITConfig:
     output_all_layers: bool = DEFAULTS.OUTPUT_ALL_LAYERS
     embedding_layers: Optional[List[int]] = None
     prepend_bos: bool = DEFAULTS.PREPEND_BOS
-    device: Optional[str] = None
+    device: Optional[Union[str, torch.device]] = None
 
 
 def _ensure_lit_available():
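With the signatures above widened from `Optional[str]` to `Optional[Union[str, torch.device]]`, callers may pass either spelling of a device. A sketch of the two now-equivalent call styles, illustrated with HookedTransformer (which the smoke tests already load this way; its acceptance of a `torch.device` is an assumption here, since only the encoder classes are annotated in this diff):

    import torch

    from transformer_lens import HookedTransformer

    # Both forms satisfy the widened annotation and should behave identically,
    # since torch's .to() accepts strings and torch.device objects alike.
    model_a = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M", device="cpu")
    model_b = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M", device=torch.device("cpu"))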
diff --git a/transformer_lens/train.py b/transformer_lens/train.py
index 24acfd0be..db6f5d769 100644
--- a/transformer_lens/train.py
+++ b/transformer_lens/train.py
@@ -5,7 +5,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.optim as optim
@@ -32,7 +32,7 @@ class HookedTransformerTrainConfig:
         max_grad_norm (float, *optional*): Maximum gradient norm to use for
         weight_decay (float, *optional*): Weight decay to use for training
         optimizer_name (str): The name of the optimizer to use
-        device (str, *optional*): Device to use for training
+        device (str or torch.device, *optional*): Device to use for training
         warmup_steps (int, *optional*): Number of warmup steps to use for training
         save_every (int, *optional*): After how many batches should a checkpoint be saved
         save_dir, (str, *optional*): Where to save checkpoints
@@ -50,7 +50,7 @@ class HookedTransformerTrainConfig:
     max_grad_norm: Optional[float] = None
     weight_decay: Optional[float] = None
     optimizer_name: str = "Adam"
-    device: Optional[str] = None
+    device: Optional[Union[str, torch.device]] = None
     warmup_steps: int = 0
     save_every: Optional[int] = None
     save_dir: Optional[str] = None
@@ -89,7 +89,7 @@ def train(
         wandb.init(project=config.wandb_project_name, config=vars(config))
 
     if config.device is None:
-        config.device = str(utils.get_device())
+        config.device = utils.get_device()
 
     optimizer: Optimizer
     if config.optimizer_name in ["Adam", "AdamW"]:
diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py
index 470646a41..6b23e1821 100644
--- a/transformer_lens/utilities/devices.py
+++ b/transformer_lens/utilities/devices.py
@@ -58,7 +58,7 @@ def get_device() -> torch.device:
 
     MPS is only auto-selected when the environment variable
     ``TRANSFORMERLENS_ALLOW_MPS=1`` is set **and** the installed PyTorch
-    version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION``.
+    version is 2.0 or higher.
 
     Returns:
         torch.device: The best available device (cuda, mps, or cpu)
@@ -82,7 +82,7 @@
     return torch.device("cpu")
 
 
-def warn_if_mps(device):
+def warn_if_mps(device: Union[str, torch.device]) -> None:
     """Emit a one-time warning if device is MPS and TRANSFORMERLENS_ALLOW_MPS is not set.
 
     Automatically suppressed when the installed PyTorch version meets or exceeds