Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 111 additions & 1 deletion tests/unit/utilities/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,14 @@ def test_move_to_and_update_config_print_details_false():

@pytest.fixture(autouse=True)
def reset_mps_warned():
    """Reset the _mps_warned and _mps_broken_torch_warned flags around each test.

    Both flags live at module level in ``transformer_lens.utilities.devices``.
    They are cleared before the test body runs and again afterwards, so a
    warning emitted in one test cannot suppress the same warning in another.
    """
    import transformer_lens.utilities.devices as devices_module

    # Setup: start every test with both warn-once flags cleared.
    devices_module._mps_warned = False
    devices_module._mps_broken_torch_warned = False
    yield
    # Teardown: leave the module clean for whatever runs next.
    devices_module._mps_warned = False
    devices_module._mps_broken_torch_warned = False


@patch.dict("os.environ", {}, clear=False)
Expand Down Expand Up @@ -291,3 +293,111 @@ def test_warn_if_mps_active_when_torch_version_below_safe():
assert len(w) == 1
finally:
devices_module._MPS_MIN_SAFE_TORCH_VERSION = original


# --- Known-broken-torch-on-MPS warning tests (issue #1062, torch 2.8.0) ---


@patch.dict("os.environ", {}, clear=False)
def test_warn_if_mps_warns_about_broken_torch_version():
    """When torch is in _MPS_BROKEN_TORCH_VERSIONS, warn_if_mps emits the broken-version warning.

    The reported torch version is patched to (2, 8) and the opt-in variable is
    removed; the recorded warnings must include both the broken-torch wording
    and a link to issue 1062.
    """
    import os

    # patch.dict restores os.environ when the test returns, so this pop is test-local.
    os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
    with patch(
        "transformer_lens.utilities.devices._torch_version_tuple",
        return_value=(2, 8),
    ):
        with warnings.catch_warnings(record=True) as w:
            # "always" defeats Python's own de-duplication of repeated warnings.
            warnings.simplefilter("always")
            warn_if_mps("mps")
    messages = [str(warning.message) for warning in w]
    assert any(
        "known MPS bug that produces silently incorrect results" in m for m in messages
    ), f"Expected broken-torch warning in {messages}"
    assert any("issues/1062" in m for m in messages)


@patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"})
def test_warn_if_mps_broken_torch_warning_fires_even_when_opted_in():
    """Opting in with TRANSFORMERLENS_ALLOW_MPS=1 must not silence the broken-torch warning.

    The opt-in only covers the generic MPS warning; the known-broken-version
    warning describes silently wrong output and has to be shown regardless.
    """
    broken_version = patch(
        "transformer_lens.utilities.devices._torch_version_tuple",
        return_value=(2, 8),
    )
    with broken_version, warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")
        warn_if_mps("mps")
    emitted = [str(entry.message) for entry in recorded]
    assert any("known MPS bug" in text for text in emitted)


@patch.dict("os.environ", {}, clear=False)
def test_warn_if_mps_no_broken_warning_on_safe_torch_version():
    """Torch versions outside the broken list must not trigger the broken-torch warning.

    Each candidate version is patched in turn, with the warn-once flag cleared
    first so every iteration is checked independently of the previous one.
    """
    import os

    import transformer_lens.utilities.devices as devices_module

    os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
    safe_versions = [(2, 7), (2, 9), (3, 0)]
    for version in safe_versions:
        version_patch = patch(
            "transformer_lens.utilities.devices._torch_version_tuple",
            return_value=version,
        )
        with version_patch:
            # Reset the broken-warn flag for each iteration
            devices_module._mps_broken_torch_warned = False
            with warnings.catch_warnings(record=True) as recorded:
                warnings.simplefilter("always")
                warn_if_mps("mps")
        messages = [str(entry.message) for entry in recorded]
        broken_seen = any("known MPS bug" in text for text in messages)
        assert (
            not broken_seen
        ), f"Unexpected broken-torch warning on torch {version}: {messages}"


@patch.dict("os.environ", {}, clear=False)
def test_warn_if_mps_broken_warning_fires_only_once():
    """Repeated warn_if_mps calls on a broken torch version yield a single broken-torch warning.

    The function is called three times (twice with the string "mps", once with
    a torch.device) and exactly one recorded warning may carry the broken-torch text.
    """
    import os

    os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
    broken_version = patch(
        "transformer_lens.utilities.devices._torch_version_tuple",
        return_value=(2, 8),
    )
    with broken_version, warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter("always")
        for target in ("mps", "mps", torch.device("mps")):
            warn_if_mps(target)
    broken_count = sum(1 for entry in recorded if "known MPS bug" in str(entry.message))
    assert broken_count == 1


def test_torch_mps_has_known_broken_bug_for_2_8():
    """With the torch version patched to (2, 8), the broken-version check must report True."""
    from transformer_lens.utilities.devices import _torch_mps_has_known_broken_bug

    broken_version = patch(
        "transformer_lens.utilities.devices._torch_version_tuple",
        return_value=(2, 8),
    )
    with broken_version:
        flagged = _torch_mps_has_known_broken_bug()
    assert flagged is True


def test_torch_mps_has_known_broken_bug_false_for_other_versions():
    """Versions other than the listed broken ones must make the broken-version check report False."""
    from transformer_lens.utilities.devices import _torch_mps_has_known_broken_bug

    unaffected_versions = [(2, 7), (2, 9), (3, 0)]
    for version in unaffected_versions:
        version_patch = patch(
            "transformer_lens.utilities.devices._torch_version_tuple",
            return_value=version,
        )
        with version_patch:
            flagged = _torch_mps_has_known_broken_bug()
        assert flagged is False, f"torch {version} incorrectly flagged as broken"
77 changes: 58 additions & 19 deletions transformer_lens/utilities/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,30 @@
# Bump this when a PyTorch release ships verified MPS fixes.
# While this is None, warn_if_mps() never treats the installed torch version as
# safe, so the generic MPS warning is not suppressed on version grounds.
_MPS_MIN_SAFE_TORCH_VERSION: tuple[int, ...] | None = None

# torch 2.8.0 on MPS has an upstream bug where torch.nn.functional.linear
# produces incorrect results for non-contiguous tensors. This silently
# corrupts generate() output and attention computations. Fixed in 2.9.0.
# See: https://github.com/pytorch/pytorch/issues/161640
# See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062
# Entries are matched by equality against _torch_version_tuple(), which keeps
# only the first two version components, so (2, 8) matches any 2.8.x build.
_MPS_BROKEN_TORCH_VERSIONS: tuple[tuple[int, ...], ...] = ((2, 8),)

# Set to True by warn_if_mps() once the known-broken-torch warning has been
# emitted, so that warning is not repeated until the flag is reset.
_mps_broken_torch_warned = False


def _torch_version_tuple() -> tuple[int, ...]:
    """Return the leading components of ``torch.__version__`` as a tuple of ints.

    Anything after a ``+`` (the local build tag) is discarded, and only the
    first two dot-separated components are kept, so patch numbers and
    pre-release suffixes attached to later components never reach ``int()``.
    """
    public_part = torch.__version__.partition("+")[0]
    leading_components = public_part.split(".")[:2]
    return tuple(map(int, leading_components))


def _torch_mps_has_known_broken_bug() -> bool:
    """Report whether the installed torch version is listed as broken on MPS.

    This is separate from the generic "MPS may be unreliable" warning: the
    versions in _MPS_BROKEN_TORCH_VERSIONS carry specific bugs, fixed upstream,
    that give silently wrong output whether or not the user has opted in.
    """
    installed = _torch_version_tuple()
    return any(installed == broken for broken in _MPS_BROKEN_TORCH_VERSIONS)


# ---------------------------------------------------------------------------
# Device helpers
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -69,28 +87,49 @@ def warn_if_mps(device):

Automatically suppressed when the installed PyTorch version meets or exceeds
_MPS_MIN_SAFE_TORCH_VERSION (currently unset — no version is considered safe yet).

Also emits a separate, stronger warning for known-broken torch versions on MPS
(see _MPS_BROKEN_TORCH_VERSIONS). This warning fires even when the user has
opted in via TRANSFORMERLENS_ALLOW_MPS=1, because the affected operations
produce silently wrong outputs regardless of opt-in.
"""
global _mps_warned
if _mps_warned:
return
global _mps_warned, _mps_broken_torch_warned
if isinstance(device, torch.device):
device = device.type
if isinstance(device, str) and device == "mps":
if (
_MPS_MIN_SAFE_TORCH_VERSION is not None
and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
):
return
if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") != "1":
_mps_warned = True
warnings.warn(
"MPS backend may produce silently incorrect results (PyTorch "
f"{torch.__version__}). "
"Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. "
"See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
UserWarning,
stacklevel=2,
)
if not (isinstance(device, str) and device == "mps"):
return

# Known-broken torch versions always warn (can't be opted-out of).
if _torch_mps_has_known_broken_bug() and not _mps_broken_torch_warned:
_mps_broken_torch_warned = True
warnings.warn(
f"PyTorch {torch.__version__} has a known MPS bug that produces "
"silently incorrect results (torch.nn.functional.linear on "
"non-contiguous tensors). This corrupts generate() output and "
"attention computations. Upgrade to torch >= 2.9.0. "
"See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062 "
"and https://github.com/pytorch/pytorch/issues/161640",
UserWarning,
stacklevel=2,
)

if _mps_warned:
return
if (
_MPS_MIN_SAFE_TORCH_VERSION is not None
and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
):
return
if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") != "1":
_mps_warned = True
warnings.warn(
"MPS backend may produce silently incorrect results (PyTorch "
f"{torch.__version__}). "
"Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. "
"See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
UserWarning,
stacklevel=2,
)


# ---------------------------------------------------------------------------
Expand Down
Loading