MPS CI Support#1278
Open
huseyincavusbi wants to merge 16 commits intoTransformerLensOrg:devfrom
Open
Conversation
* Fix type of HookedTransformerConfig.device This is typed as `Optional[str]` but sometimes returns `torch.device`. Updated the code to just return the `str` instead of wrapping with a device. I'm not confident that every function which takes a device will always be passed a string, so I didn't change functions like warn_if_mps. Found while working on TransformerLensOrg#1219 * more cleanup * 3.0 CI Bugs (TransformerLensOrg#1261) * Fixing `utils` imports * skip gated notebooks on PR from forks * Updating notebooks * Ensure LLaMA only runs when HF_TOKEN is available --------- Co-authored-by: jlarson4 <jonahalarson@comcast.net>
TransformerLens 3.1.0
Contributor
There was a problem hiding this comment.
Pull request overview
Adds Apple Silicon MPS coverage to CI by introducing an MPS-specific test suite and a macOS GitHub Actions job, alongside device-selection tweaks to make MPS opt-in by default.
Changes:
- Added a new
tests/mpssmoke-test suite that validates basic tensor ops and a smallHookedTransformerrun on MPS. - Added an
mps-checksGitHub Actions job onmacos-latestto run unit/integration tests plus the new MPS smoke tests on PRs tomainand pushes tomain. - Updated device utilities and configs to better support MPS opt-in behavior, plus proactive
torch.mps.empty_cache()cleanup in pytest fixtures.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| transformer_lens/utilities/devices.py | Adjusts device selection behavior/signatures to support MPS opt-in and updated warning typing. |
| transformer_lens/train.py | Updates training config typing and default device assignment. |
| transformer_lens/config/HookedTransformerConfig.py | Uses get_device() directly when defaulting cfg.device. |
| tests/unit/utilities/test_devices.py | Updates device utility unit tests for the new get_device() return type. |
| tests/mps/test_mps_basic.py | Adds MPS-only smoke tests covering device detection, core ops, and small-model forward/cache paths. |
| tests/mps/init.py | Declares the MPS test package. |
| tests/conftest.py | Adds MPS cache clearing after tests/classes/session to reduce CI OOM risk. |
| pyproject.toml | Registers a no_mps pytest marker. |
| .github/workflows/checks.yml | Adds the mps-checks CI job that runs on macOS and executes the MPS tests. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+56
to
+67
| def get_device() -> str: | ||
| """Get the best available device, with MPS safety checks. | ||
|
|
||
| MPS is only auto-selected when the environment variable | ||
| ``TRANSFORMERLENS_ALLOW_MPS=1`` is set **and** the installed PyTorch | ||
| version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION``. | ||
|
|
||
| Returns: | ||
| torch.device: The best available device (cuda, mps, or cpu) | ||
| str: The best available device name (cuda, mps, or cpu) | ||
| """ | ||
| if torch.cuda.is_available(): | ||
| return torch.device("cuda") | ||
| return "cuda" |
Comment on lines
59
to
62
| MPS is only auto-selected when the environment variable | ||
| ``TRANSFORMERLENS_ALLOW_MPS=1`` is set **and** the installed PyTorch | ||
| version meets or exceeds ``_MPS_MIN_SAFE_TORCH_VERSION``. | ||
|
|
Comment on lines
33
to
54
| weight_decay (float, *optional*): Weight decay to use for training | ||
| optimizer_name (str): The name of the optimizer to use | ||
| device (str, *optional*): Device to use for training | ||
| warmup_steps (int, *optional*): Number of warmup steps to use for training | ||
| save_every (int, *optional*): After how many batches should a checkpoint be saved | ||
| save_dir, (str, *optional*): Where to save checkpoints | ||
| wandb (bool): Whether to use Weights and Biases for logging | ||
| wandb_project (str, *optional*): Name of the Weights and Biases project to use | ||
| print_every (int, *optional*): Print the loss every n steps | ||
| max_steps (int, *optional*): Terminate the epoch after this many steps. Used for debugging. | ||
| """ | ||
|
|
||
| num_epochs: int | ||
| batch_size: int | ||
| lr: float = 1e-3 | ||
| seed: int = 0 | ||
| momentum: float = 0.0 | ||
| max_grad_norm: Optional[float] = None | ||
| weight_decay: Optional[float] = None | ||
| optimizer_name: str = "Adam" | ||
| device: Optional[str] = None | ||
| device: Optional[Union[str, torch.device]] = None | ||
| warmup_steps: int = 0 |
Comment on lines
+85
to
+89
| assert device in ("cpu", "mps"), f"Unexpected device: {device}" | ||
| if original == "": # env var was not set originally | ||
| assert ( | ||
| device == "cpu" | ||
| ), "Without TRANSFORMERLENS_ALLOW_MPS=1, get_device() should return 'cpu' not 'mps'" |
Contributor
Author
|
Hi @jlarson4, I've updated the PR to address the automated feedback:
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Hi @jlarson4,
This PR implements MPS (Metal Performance Shaders) CI Runner Support as proposed in #1264.
The goal is to provide automated testing for the Apple Silicon research community while working within the limits of GitHub's Mac runners.
Key Changes:
tests/mps/test_mps_basic.pywith 11 smoke tests covering device detection, core tensor ops on Metal, andHookedTransformerforward passes/caching with small models (TinyStories-1M).mps-checksjob in.github/workflows/checks.yml. It usesmacos-latestand runs only on PRs/pushes to main.tests/conftest.pyto proactively clear the MPS cache after every test usingtorch.mps.empty_cache().model_bridge) to ensure stability.TRANSFORMERLENS_ALLOW_MPS=1to ensure safe defaults for Mac users.Type of change
Checklist: