diff --git a/tests/unit/utilities/test_devices.py b/tests/unit/utilities/test_devices.py
index 5e1af5632..6ea3b4095 100644
--- a/tests/unit/utilities/test_devices.py
+++ b/tests/unit/utilities/test_devices.py
@@ -158,12 +158,14 @@ def test_move_to_and_update_config_print_details_false():
 
 @pytest.fixture(autouse=True)
 def reset_mps_warned():
-    """Reset the _mps_warned flag before each test."""
+    """Reset the _mps_warned and _mps_broken_torch_warned flags before each test."""
     import transformer_lens.utilities.devices as devices_module
 
     devices_module._mps_warned = False
+    devices_module._mps_broken_torch_warned = False
     yield
     devices_module._mps_warned = False
+    devices_module._mps_broken_torch_warned = False
 
 
 @patch.dict("os.environ", {}, clear=False)
@@ -291,3 +293,111 @@ def test_warn_if_mps_active_when_torch_version_below_safe():
             assert len(w) == 1
     finally:
         devices_module._MPS_MIN_SAFE_TORCH_VERSION = original
+
+
+# --- Known-broken-torch-on-MPS warning tests (issue #1062, torch 2.8.0) ---
+
+
+@patch.dict("os.environ", {}, clear=False)
+def test_warn_if_mps_warns_about_broken_torch_version():
+    """When torch is in _MPS_BROKEN_TORCH_VERSIONS, warn_if_mps emits the broken-version warning."""
+    import os
+
+    import transformer_lens.utilities.devices as devices_module
+
+    os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
+    with patch(
+        "transformer_lens.utilities.devices._torch_version_tuple",
+        return_value=(2, 8),
+    ):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            warn_if_mps("mps")
+        messages = [str(warning.message) for warning in w]
+        assert any(
+            "known MPS bug that produces silently incorrect results" in m for m in messages
+        ), f"Expected broken-torch warning in {messages}"
+        assert any("issues/1062" in m for m in messages)
+
+
+@patch.dict("os.environ", {"TRANSFORMERLENS_ALLOW_MPS": "1"})
+def test_warn_if_mps_broken_torch_warning_fires_even_when_opted_in():
+    """The broken-torch warning must fire even with TRANSFORMERLENS_ALLOW_MPS=1,
+    because the bug produces silently wrong output regardless of opt-in."""
+    with patch(
+        "transformer_lens.utilities.devices._torch_version_tuple",
+        return_value=(2, 8),
+    ):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            warn_if_mps("mps")
+        messages = [str(warning.message) for warning in w]
+        assert any("known MPS bug" in m for m in messages)
+
+
+@patch.dict("os.environ", {}, clear=False)
+def test_warn_if_mps_no_broken_warning_on_safe_torch_version():
+    """Non-broken torch versions should not emit the broken-torch warning."""
+    import os
+
+    os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
+    for version in [(2, 7), (2, 9), (3, 0)]:
+        with patch(
+            "transformer_lens.utilities.devices._torch_version_tuple",
+            return_value=version,
+        ):
+            # Reset the broken-warn flag for each iteration
+            import transformer_lens.utilities.devices as devices_module
+
+            devices_module._mps_broken_torch_warned = False
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter("always")
+                warn_if_mps("mps")
+            messages = [str(warning.message) for warning in w]
+            assert not any(
+                "known MPS bug" in m for m in messages
+            ), f"Unexpected broken-torch warning on torch {version}: {messages}"
+
+
+@patch.dict("os.environ", {}, clear=False)
+def test_warn_if_mps_broken_warning_fires_only_once():
+    """The broken-torch warning should only fire once per process."""
+    import os
+
+    os.environ.pop("TRANSFORMERLENS_ALLOW_MPS", None)
+    with patch(
+        "transformer_lens.utilities.devices._torch_version_tuple",
+        return_value=(2, 8),
+    ):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            warn_if_mps("mps")
+            warn_if_mps("mps")
+            warn_if_mps(torch.device("mps"))
+        broken_warnings = [warning for warning in w if "known MPS bug" in str(warning.message)]
+        assert len(broken_warnings) == 1
+
+
+def test_torch_mps_has_known_broken_bug_for_2_8():
+    """_torch_mps_has_known_broken_bug should return True for torch 2.8."""
+    from transformer_lens.utilities.devices import _torch_mps_has_known_broken_bug
+
+    with patch(
+        "transformer_lens.utilities.devices._torch_version_tuple",
+        return_value=(2, 8),
+    ):
+        assert _torch_mps_has_known_broken_bug() is True
+
+
+def test_torch_mps_has_known_broken_bug_false_for_other_versions():
+    """_torch_mps_has_known_broken_bug should return False for non-broken torch versions."""
+    from transformer_lens.utilities.devices import _torch_mps_has_known_broken_bug
+
+    for version in [(2, 7), (2, 9), (3, 0)]:
+        with patch(
+            "transformer_lens.utilities.devices._torch_version_tuple",
+            return_value=version,
+        ):
+            assert (
+                _torch_mps_has_known_broken_bug() is False
+            ), f"torch {version} incorrectly flagged as broken"
diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py
index d1265f002..470646a41 100644
--- a/transformer_lens/utilities/devices.py
+++ b/transformer_lens/utilities/devices.py
@@ -24,12 +24,30 @@
 # Bump this when a PyTorch release ships verified MPS fixes.
 _MPS_MIN_SAFE_TORCH_VERSION: tuple[int, ...] | None = None
 
+# torch 2.8.0 on MPS has an upstream bug where torch.nn.functional.linear
+# produces incorrect results for non-contiguous tensors. This silently
+# corrupts generate() output and attention computations. Fixed in 2.9.0.
+# See: https://github.com/pytorch/pytorch/issues/161640
+# See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062
+_MPS_BROKEN_TORCH_VERSIONS: tuple[tuple[int, ...], ...] = ((2, 8),)
+
+_mps_broken_torch_warned = False
+
 
 def _torch_version_tuple() -> tuple[int, ...]:
     """Parse torch.__version__ into a comparable tuple, ignoring pre-release suffixes."""
     return tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
 
 
+def _torch_mps_has_known_broken_bug() -> bool:
+    """True if the installed torch version has a known-broken MPS path.
+
+    Distinct from the generic MPS-may-be-unreliable warning: these are specific,
+    upstream-fixed bugs where output is silently wrong regardless of opt-in.
+    """
+    return _torch_version_tuple() in _MPS_BROKEN_TORCH_VERSIONS
+
+
 # ---------------------------------------------------------------------------
 # Device helpers
 # ---------------------------------------------------------------------------
@@ -69,28 +87,49 @@ def warn_if_mps(device):
     Automatically suppressed when the installed PyTorch version meets or exceeds
     _MPS_MIN_SAFE_TORCH_VERSION (currently unset — no version is considered safe
     yet).
+
+    Also emits a separate, stronger warning for known-broken torch versions on MPS
+    (see _MPS_BROKEN_TORCH_VERSIONS). This warning fires even when the user has
+    opted in via TRANSFORMERLENS_ALLOW_MPS=1, because the affected operations
+    produce silently wrong outputs regardless of opt-in.
     """
-    global _mps_warned
-    if _mps_warned:
-        return
+    global _mps_warned, _mps_broken_torch_warned
     if isinstance(device, torch.device):
         device = device.type
-    if isinstance(device, str) and device == "mps":
-        if (
-            _MPS_MIN_SAFE_TORCH_VERSION is not None
-            and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
-        ):
-            return
-        if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") != "1":
-            _mps_warned = True
-            warnings.warn(
-                "MPS backend may produce silently incorrect results (PyTorch "
-                f"{torch.__version__}). "
-                "Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. "
-                "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
-                UserWarning,
-                stacklevel=2,
-            )
+    if not (isinstance(device, str) and device == "mps"):
+        return
+
+    # Known-broken torch versions always warn (can't be opted-out of).
+    if _torch_mps_has_known_broken_bug() and not _mps_broken_torch_warned:
+        _mps_broken_torch_warned = True
+        warnings.warn(
+            f"PyTorch {torch.__version__} has a known MPS bug that produces "
+            "silently incorrect results (torch.nn.functional.linear on "
+            "non-contiguous tensors). This corrupts generate() output and "
+            "attention computations. Upgrade to torch >= 2.9.0. "
+            "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1062 "
+            "and https://github.com/pytorch/pytorch/issues/161640",
+            UserWarning,
+            stacklevel=2,
+        )
+
+    if _mps_warned:
+        return
+    if (
+        _MPS_MIN_SAFE_TORCH_VERSION is not None
+        and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
+    ):
+        return
+    if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") != "1":
+        _mps_warned = True
+        warnings.warn(
+            "MPS backend may produce silently incorrect results (PyTorch "
+            f"{torch.__version__}). "
+            "Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. "
+            "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
+            UserWarning,
+            stacklevel=2,
+        )
 
 
 # ---------------------------------------------------------------------------