MPS CI Support #1278

Open
huseyincavusbi wants to merge 16 commits into TransformerLensOrg:dev from huseyincavusbi:feat/mps-ci-support

Contributor

@huseyincavusbi huseyincavusbi commented May 2, 2026

Hi @jlarson4,

This PR implements MPS (Metal Performance Shaders) CI Runner Support as proposed in #1264.

The goal is to provide automated testing for the Apple Silicon research community while working within the limits of GitHub's Mac runners.

Key Changes:

  • New Test Suite: Added tests/mps/test_mps_basic.py with 11 smoke tests covering device detection, core tensor ops on Metal, and HookedTransformer forward passes/caching with small models (TinyStories-1M).
  • CI Automation: Introduced the mps-checks job in .github/workflows/checks.yml. It uses macos-latest and runs only on PRs/pushes to main.
  • Memory Management:
    • Updated tests/conftest.py to proactively clear the MPS cache after every test using torch.mps.empty_cache().
    • Configured the CI to ignore memory-intensive modules (e.g., model_bridge) to ensure stability.
  • Opt-in Mechanism: Respects TRANSFORMERLENS_ALLOW_MPS=1 to ensure safe defaults for Mac users.
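
As a rough illustration of the kind of smoke test described above, the sketch below runs a small matrix multiply on MPS and compares it with the CPU result. It is not taken from tests/mps/test_mps_basic.py; the helper `matmul_roundtrip` and the test name are hypothetical, and the test is skipped when no MPS backend is present.

```python
import pytest
import torch

# True only on machines where PyTorch can actually use the Metal backend.
MPS_AVAILABLE = torch.backends.mps.is_available()


def matmul_roundtrip(device: str) -> torch.Tensor:
    """Multiply two small matrices on `device` and return the result on CPU.

    Hypothetical helper for illustration; not part of the PR.
    """
    a = torch.arange(4, dtype=torch.float32, device=device).reshape(2, 2)
    b = torch.eye(2, device=device)
    return (a @ b).cpu()


@pytest.mark.skipif(not MPS_AVAILABLE, reason="MPS backend not available")
def test_matmul_on_mps_matches_cpu():
    # The same op on Metal and on CPU should agree for this tiny input.
    assert torch.allclose(matmul_roundtrip("mps"), matmul_roundtrip("cpu"))
```

On a non-Apple runner the test is reported as skipped rather than failed, which keeps the same file usable across all CI jobs.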

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

brendanlong and others added 10 commits April 20, 2026 14:50
* Fix type of HookedTransformerConfig.device

This is typed as `Optional[str]` but sometimes returns `torch.device`.
Updated the code to just return the `str` instead of wrapping with a
device.

I'm not confident that every function which takes a device will
always be passed a string, so I didn't change functions like
warn_if_mps.

Found while working on TransformerLensOrg#1219

* more cleanup

* 3.0 CI Bugs (TransformerLensOrg#1261)

* Fixing `utils` imports

* skip gated notebooks on PR from forks

* Updating notebooks

* Ensure LLaMA only runs when HF_TOKEN is available

---------

Co-authored-by: jlarson4 <jonahalarson@comcast.net>
Copilot AI review requested due to automatic review settings May 2, 2026 16:03
Contributor

Copilot AI left a comment

Pull request overview

Adds Apple Silicon MPS coverage to CI by introducing an MPS-specific test suite and a macOS GitHub Actions job, alongside device-selection tweaks to make MPS opt-in by default.

Changes:

  • Added a new tests/mps smoke-test suite that validates basic tensor ops and a small HookedTransformer run on MPS.
  • Added an mps-checks GitHub Actions job on macos-latest to run unit/integration tests plus the new MPS smoke tests on PRs to main and pushes to main.
  • Updated device utilities and configs to better support MPS opt-in behavior, plus proactive torch.mps.empty_cache() cleanup in pytest fixtures.
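
The cache cleanup mentioned in the last bullet could look roughly like the following conftest.py sketch. It assumes an autouse, function-scoped fixture; the names `clear_mps_cache` and `_clear_mps_cache_after_test` are hypothetical, and the PR's actual tests/conftest.py (which also clears after classes and the session) may be structured differently.

```python
import pytest
import torch


def clear_mps_cache() -> bool:
    """Release cached MPS memory if the backend is available.

    Returns True when a clear was performed, False otherwise.
    Hypothetical helper for illustration; not the PR's code.
    """
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
        return True
    return False


@pytest.fixture(autouse=True)
def _clear_mps_cache_after_test():
    # Run the test first, then free cached Metal allocations so memory
    # does not accumulate across tests on a constrained CI runner.
    yield
    clear_mps_cache()
```

Putting the cleanup after the `yield` means it runs even when the test fails, since pytest executes fixture teardown in either case.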

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| transformer_lens/utilities/devices.py | Adjusts device selection behavior/signatures to support MPS opt-in and updated warning typing. |
| transformer_lens/train.py | Updates training config typing and default device assignment. |
| transformer_lens/config/HookedTransformerConfig.py | Uses get_device() directly when defaulting cfg.device. |
| tests/unit/utilities/test_devices.py | Updates device utility unit tests for the new get_device() return type. |
| tests/mps/test_mps_basic.py | Adds MPS-only smoke tests covering device detection, core ops, and small-model forward/cache paths. |
| tests/mps/__init__.py | Declares the MPS test package. |
| tests/conftest.py | Adds MPS cache clearing after tests/classes/session to reduce CI OOM risk. |
| pyproject.toml | Registers a no_mps pytest marker. |
| .github/workflows/checks.yml | Adds the mps-checks CI job that runs on macOS and executes the MPS tests. |

Comment thread transformer_lens/utilities/devices.py Outdated
Comment on lines +56 to +67
  def get_device() -> str:
      """Get the best available device, with MPS safety checks.

      MPS is only auto-selected when the environment variable
      ``TRANSFORMERLENS_ALLOW_MPS=1`` is set **and** the installed PyTorch
      version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION``.

      Returns:
-         torch.device: The best available device (cuda, mps, or cpu)
+         str: The best available device name (cuda, mps, or cpu)
      """
      if torch.cuda.is_available():
-         return torch.device("cuda")
+         return "cuda"
Comment on lines 59 to 62
MPS is only auto-selected when the environment variable
``TRANSFORMERLENS_ALLOW_MPS=1`` is set **and** the installed PyTorch
version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION``.
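
To make the opt-in rule in this docstring concrete, here is a minimal sketch of the selection logic as a pure function. It is not the code in transformer_lens/utilities/devices.py: `choose_device` and `_version_tuple` are hypothetical names, availability flags are passed in rather than queried from torch, and the minimum safe version is a required argument because the value of `_MPS_MIN_SAFE_TORCH_VERSION` is not shown in this thread. It returns a device name as a string for simplicity, whereas the later revision of the PR returns torch.device objects.

```python
import os
from typing import Mapping, Optional, Tuple


def _version_tuple(version: str) -> Tuple[int, int]:
    """Parse a version like '2.3.1+cpu' into (2, 3)."""
    major, minor = version.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def choose_device(
    cuda_available: bool,
    mps_available: bool,
    torch_version: str,
    min_safe_version: Tuple[int, int],
    env: Optional[Mapping[str, str]] = None,
) -> str:
    """Return 'cuda', 'mps', or 'cpu' following the opt-in rule (illustrative)."""
    env = os.environ if env is None else env
    if cuda_available:
        return "cuda"
    # MPS needs both the explicit opt-in and a new enough PyTorch.
    opted_in = env.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1"
    if mps_available and opted_in and _version_tuple(torch_version) >= min_safe_version:
        return "mps"
    return "cpu"
```

Passing the environment and availability flags as arguments keeps the rule testable on any machine, including CI runners without Apple Silicon.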

Comment thread transformer_lens/train.py
Comment on lines 33 to 54
          weight_decay (float, *optional*): Weight decay to use for training
          optimizer_name (str): The name of the optimizer to use
          device (str, *optional*): Device to use for training
          warmup_steps (int, *optional*): Number of warmup steps to use for training
          save_every (int, *optional*): After how many batches should a checkpoint be saved
          save_dir, (str, *optional*): Where to save checkpoints
          wandb (bool): Whether to use Weights and Biases for logging
          wandb_project (str, *optional*): Name of the Weights and Biases project to use
          print_every (int, *optional*): Print the loss every n steps
          max_steps (int, *optional*): Terminate the epoch after this many steps. Used for debugging.
      """

      num_epochs: int
      batch_size: int
      lr: float = 1e-3
      seed: int = 0
      momentum: float = 0.0
      max_grad_norm: Optional[float] = None
      weight_decay: Optional[float] = None
      optimizer_name: str = "Adam"
-     device: Optional[str] = None
+     device: Optional[Union[str, torch.device]] = None
      warmup_steps: int = 0
Comment thread tests/mps/test_mps_basic.py Outdated
Comment on lines +85 to +89
assert device in ("cpu", "mps"), f"Unexpected device: {device}"
if original == "": # env var was not set originally
assert (
device == "cpu"
), "Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not 'mps'"
@huseyincavusbi
Contributor Author

Hi @jlarson4, I've updated the PR to address the automated feedback:

  • API Stability: Reverted get_device() to return torch.device objects.
  • Type Checks: Updated type hints across model classes to resolve mypy failures.
  • CI Trigger: Restricted mps-checks to run only for the main branch.

4 participants