diff --git a/nemo_automodel/_transformers/kernel_patches.py b/nemo_automodel/_transformers/kernel_patches.py
index 3aadffc56e..6438ab3466 100644
--- a/nemo_automodel/_transformers/kernel_patches.py
+++ b/nemo_automodel/_transformers/kernel_patches.py
@@ -136,6 +136,51 @@ def _patch_liger_kernel(model):
     raise RuntimeError("Failed to patch model")
 
 
+def _patch_legacy_flash_attn_flag():
+    """Bridge the legacy ``_supports_flash_attn_2`` class flag to v5.5's
+    ``_supports_flash_attn``.
+
+    transformers v5.5 renamed the FA2-support attribute from
+    ``_supports_flash_attn_2`` to ``_supports_flash_attn`` and switched the
+    dispatch check at ``_flash_attn_can_dispatch`` to the new name only.
+    Remote-code models pinned against <=v5.3 (e.g. microsoft/Phi-4-multimodal-instruct
+    sets ``_supports_flash_attn_2 = True`` in its modeling file) are not aware
+    of the rename, so their FA2 support is invisible to v5.5 and
+    ``attn_implementation="flash_attention_2"`` raises ``ValueError``.
+
+    Install a property on ``PreTrainedModel._supports_flash_attn`` that falls
+    back to the legacy flag when a subclass has not set the new one. Subclasses
+    that set ``_supports_flash_attn = True`` directly still shadow the property
+    via normal MRO lookup, so native models are unaffected.
+    """
+    import transformers.modeling_utils as mu
+
+    base = mu.PreTrainedModel
+    if getattr(base, "_nemo_fa2_flag_bridged", False):
+        return
+
+    # Capture the base-class default (``False`` on v5.5) so the fallback
+    # preserves original behavior when no flag is set anywhere.
+    _base_default = base.__dict__.get("_supports_flash_attn", False)
+
+    def _supports_flash_attn_fget(self):
+        cls = type(self)
+        for klass in cls.__mro__:
+            # Stop at the base — the property lives here; anything below is
+            # just the captured default.
+            if klass is base:
+                break
+            d = klass.__dict__
+            if "_supports_flash_attn" in d:
+                return d["_supports_flash_attn"]
+            if d.get("_supports_flash_attn_2") is True:
+                return True
+        return _base_default
+
+    base._supports_flash_attn = property(_supports_flash_attn_fget)
+    base._nemo_fa2_flag_bridged = True  # type: ignore[attr-defined]
+
+
 def _get_next_fallback_attn(attn_implementation: str) -> str:
     """
     Get the next attention implementation in the priority list, in reverse order.
diff --git a/nemo_automodel/_transformers/utils.py b/nemo_automodel/_transformers/utils.py
index f0525d4ebc..06e303c1c6 100644
--- a/nemo_automodel/_transformers/utils.py
+++ b/nemo_automodel/_transformers/utils.py
@@ -236,6 +236,10 @@ def _patched_post_init(self):
     _patch_phi4mm_processor()
     _patch_peft_prepare_inputs()
 
+    from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag
+
+    _patch_legacy_flash_attn_flag()
+
 
 def _patch_phi4mm_processor():
     """Patch AutoProcessor.from_pretrained to fall back to the remote
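The MRO walk is the subtle part of the patch: the property must honor a subclass's explicit `_supports_flash_attn` before any ancestor's legacy flag, yet stop before reaching the property itself on the base class. A minimal standalone sketch of that mechanism (`Base`, `Legacy`, and `Native` are stand-in names for illustration, not transformers classes):

```python
class Base:
    """Stand-in for PreTrainedModel; only the flag machinery is modeled."""
    _supports_flash_attn_2 = False


def _fget(self):
    # Walk the MRO strictly below Base, mirroring the patch above.
    for klass in type(self).__mro__:
        if klass is Base:
            break
        d = klass.__dict__
        if "_supports_flash_attn" in d:
            return d["_supports_flash_attn"]  # explicit new flag wins
        if d.get("_supports_flash_attn_2") is True:
            return True  # bridge: legacy True implies FA2 support
    return False  # captured base default


Base._supports_flash_attn = property(_fget)


class Legacy(Base):  # pre-v5.5 remote-code style: only the old flag
    _supports_flash_attn_2 = True


class Native(Base):  # v5.5-native style: plain attribute shadows the property
    _supports_flash_attn = True


assert Legacy()._supports_flash_attn is True   # bridged from the legacy flag
assert Native()._supports_flash_attn is True   # property never consulted
assert Base()._supports_flash_attn is False    # no flag anywhere -> default
```

Because `Native` defines `_supports_flash_attn` as a plain class attribute, ordinary attribute lookup finds it before the property on `Base`, so v5.5-native models never pay for the fallback.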
diff --git a/nemo_automodel/components/models/qwen3_5_moe/cp_linear_attn.py b/nemo_automodel/components/models/qwen3_5_moe/cp_linear_attn.py
index 07b414e69f..63a86bb0fc 100644
--- a/nemo_automodel/components/models/qwen3_5_moe/cp_linear_attn.py
+++ b/nemo_automodel/components/models/qwen3_5_moe/cp_linear_attn.py
@@ -97,8 +97,11 @@ def _forward_no_cp(
 ):
     """HF GatedDeltaNet forward with FSDP-safe fp32 gate computation.
 
-    Copied from transformers==5.3.0 Qwen3_5GatedDeltaNet.forward
-    with gate computation replaced by self._compute_gate(a).
+    Mirrors transformers==5.5 Qwen3_5GatedDeltaNet.forward (uses the
+    per-layer cache API: ``has_previous_state(layer_idx)``,
+    ``cache_params.layers[layer_idx].{conv,recurrent}_states``, and the
+    ``update_{conv,recurrent}_state`` methods) with the gate computation
+    replaced by ``self._compute_gate(a)``.
     """
     from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states
 
@@ -106,12 +109,12 @@ def _forward_no_cp(
     batch_size, seq_len, _ = hidden_states.shape
 
     use_precomputed_states = (
-        cache_params is not None and cache_params.has_previous_state and seq_len == 1 and cache_position is not None
+        cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
     )
 
-    if cache_params is not None:
-        conv_state = cache_params.conv_states[self.layer_idx]
-        recurrent_state = cache_params.recurrent_states[self.layer_idx]
+    if use_precomputed_states:
+        conv_state = cache_params.layers[self.layer_idx].conv_states
+        recurrent_state = cache_params.layers[self.layer_idx].recurrent_states
 
     mixed_qkv = self.in_proj_qkv(hidden_states)
     mixed_qkv = mixed_qkv.transpose(1, 2)
@@ -133,7 +136,7 @@ def _forward_no_cp(
     else:
         if cache_params is not None:
             conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
-            cache_params.conv_states[self.layer_idx] = conv_state
+            cache_params.update_conv_state(conv_state, self.layer_idx)
         if self.causal_conv1d_fn is not None:
             mixed_qkv = self.causal_conv1d_fn(
                 x=mixed_qkv,
@@ -183,7 +186,7 @@ def _forward_no_cp(
     )
 
     if cache_params is not None:
-        cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
+        cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx)
 
     core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
     z = z.reshape(-1, self.head_v_dim)
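For reference, the old-vs-new cache contract driving these changes, condensed into two helpers. This is a sketch: `cache` and `idx` are placeholders for the cache object and `self.layer_idx`, and the method names and attribute paths are the ones used in the diff above.

```python
def read_states(cache, idx):
    # v5.5: per-layer query plus per-layer state containers.
    if cache is not None and cache.has_previous_state(idx):
        return cache.layers[idx].conv_states, cache.layers[idx].recurrent_states
    return None, None


def write_states(cache, idx, conv_state, recurrent_state):
    # v5.5: dedicated update methods replace the pre-v5.5 direct assignment
    # (cache.conv_states[idx] = conv_state), which no longer exists.
    if cache is not None:
        cache.update_conv_state(conv_state, idx)
        cache.update_recurrent_state(recurrent_state, idx)
```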
+ """ + + def test_installs_property_on_base(self): + import transformers.modeling_utils as mu + + from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag + + _patch_legacy_flash_attn_flag() + assert isinstance(mu.PreTrainedModel.__dict__["_supports_flash_attn"], property) + + def test_is_idempotent(self): + import transformers.modeling_utils as mu + + from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag + + _patch_legacy_flash_attn_flag() + prop1 = mu.PreTrainedModel.__dict__["_supports_flash_attn"] + _patch_legacy_flash_attn_flag() + prop2 = mu.PreTrainedModel.__dict__["_supports_flash_attn"] + assert prop1 is prop2 + + def test_legacy_flag_bridged_to_true(self): + """Subclass with only ``_supports_flash_attn_2 = True`` resolves to True.""" + import transformers.modeling_utils as mu + + from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag + + _patch_legacy_flash_attn_flag() + + class _Legacy(mu.PreTrainedModel): + _supports_flash_attn_2 = True + + assert _Legacy.__new__(_Legacy)._supports_flash_attn is True + + def test_explicit_new_flag_true_wins(self): + """Subclass that sets ``_supports_flash_attn = True`` directly shadows the property.""" + import transformers.modeling_utils as mu + + from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag + + _patch_legacy_flash_attn_flag() + + class _Native(mu.PreTrainedModel): + _supports_flash_attn = True + + assert _Native.__new__(_Native)._supports_flash_attn is True + + def test_explicit_new_flag_false_wins_over_legacy_true(self): + """Explicit ``_supports_flash_attn = False`` shadows a legacy True.""" + import transformers.modeling_utils as mu + + from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag + + _patch_legacy_flash_attn_flag() + + class _Native(mu.PreTrainedModel): + _supports_flash_attn = False + _supports_flash_attn_2 = True + + assert _Native.__new__(_Native)._supports_flash_attn is False + + def test_neither_flag_falls_back_to_base_default(self): + """Subclass with neither flag falls back to the captured base default (False).""" + import transformers.modeling_utils as mu + + from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag + + _patch_legacy_flash_attn_flag() + + class _Bare(mu.PreTrainedModel): + pass + + assert _Bare.__new__(_Bare)._supports_flash_attn is False + + def test_legacy_flag_false_does_not_bridge(self): + """Only ``_supports_flash_attn_2 is True`` bridges; False passes through.""" + import transformers.modeling_utils as mu + + from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag + + _patch_legacy_flash_attn_flag() + + class _LegacyFalse(mu.PreTrainedModel): + _supports_flash_attn_2 = False + + assert _LegacyFalse.__new__(_LegacyFalse)._supports_flash_attn is False + + def test_nearest_subclass_wins_in_mro(self): + """In multi-level inheritance, the nearest ``_supports_flash_attn`` in MRO wins.""" + import transformers.modeling_utils as mu + + from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag + + _patch_legacy_flash_attn_flag() + + class _Ancestor(mu.PreTrainedModel): + _supports_flash_attn_2 = True + + class _Mid(_Ancestor): + _supports_flash_attn = False + + class _Leaf(_Mid): + pass + + assert _Leaf.__new__(_Leaf)._supports_flash_attn is False + + class DummyModel(torch.nn.Module): """A tiny nn.Module that behaves enough like a HF/BERT style model.""" diff 
diff --git a/tests/unit_tests/models/qwen3_5_moe/test_cp_linear_attn.py b/tests/unit_tests/models/qwen3_5_moe/test_cp_linear_attn.py
index 748a65b988..8e1e3fdf8c 100644
--- a/tests/unit_tests/models/qwen3_5_moe/test_cp_linear_attn.py
+++ b/tests/unit_tests/models/qwen3_5_moe/test_cp_linear_attn.py
@@ -302,6 +302,68 @@ def test_cp_mesh_gt_1_calls_forward_with_cp(self, module, device):
         mock_cp_fwd.assert_called_once()
 
 
+class TestForwardNoCpV55CacheAPI:
+    """_forward_no_cp must use the transformers v5.5 per-layer cache API.
+
+    v5.5 renamed ``has_previous_state`` to a method taking ``layer_idx``, moved
+    states under ``cache.layers[layer_idx]``, and exposes ``update_conv_state`` /
+    ``update_recurrent_state`` methods instead of direct dict assignment. A plain
+    ``DynamicCache`` (no pre-existing state, as in training) has no top-level
+    ``conv_states`` attribute — the pre-v5.5 pattern raised ``AttributeError``.
+    """
+
+    def test_training_cache_no_previous_state_runs(self, module, text_config, device):
+        """Training-style forward with a fresh DynamicCache (no previous state) must not raise."""
+        from transformers import DynamicCache
+
+        B, S, D = 1, 8, module.hidden_size
+        hidden = torch.randn(B, S, D, device=device)
+        out = module._forward_no_cp(hidden, cache_params=DynamicCache(config=text_config))
+        assert out.shape == (B, S, D)
+
+    def test_no_cache_path_still_works(self, module, device):
+        """When cache_params is None, _forward_no_cp runs the pure compute path."""
+        B, S, D = 1, 8, module.hidden_size
+        hidden = torch.randn(B, S, D, device=device)
+        out = module._forward_no_cp(hidden, cache_params=None)
+        assert out.shape == (B, S, D)
+
+    def test_updates_conv_state_via_method(self, module, text_config, device):
+        """Prefill writes the conv state via ``update_conv_state(state, layer_idx)``."""
+        from transformers import DynamicCache
+
+        B, S, D = 1, 8, module.hidden_size
+        hidden = torch.randn(B, S, D, device=device)
+        cache = DynamicCache(config=text_config)
+        with (
+            patch.object(cache, "update_conv_state", wraps=cache.update_conv_state) as mock_update_conv,
+            patch.object(cache, "update_recurrent_state", wraps=cache.update_recurrent_state) as mock_update_rec,
+        ):
+            module._forward_no_cp(hidden, cache_params=cache)
+        mock_update_conv.assert_called_once()
+        # Written at the layer_idx owned by the module.
+        args, _ = mock_update_conv.call_args
+        assert args[1] == module.layer_idx
+        mock_update_rec.assert_called_once()
+
+    def test_has_previous_state_called_as_method_with_layer_idx(self, module, text_config, device):
+        """v5.5 ``has_previous_state`` is a method that takes ``layer_idx``."""
+        from transformers import DynamicCache
+
+        B, S, D = 1, 8, module.hidden_size
+        hidden = torch.randn(B, S, D, device=device)
+        cache = DynamicCache(config=text_config)
+        with patch.object(cache, "has_previous_state", wraps=cache.has_previous_state) as mock_hps:
+            module._forward_no_cp(hidden, cache_params=cache)
+        mock_hps.assert_called()
+        # At least one call must pass the module's layer_idx.
+        layer_idx_seen = any(
+            (call.args and call.args[0] == module.layer_idx) or call.kwargs.get("layer_idx") == module.layer_idx
+            for call in mock_hps.call_args_list
+        )
+        assert layer_idx_seen
+
+
 # ============================================================================
 # _conv1d_with_cp
 # ============================================================================
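One way to run just the two new test classes, via pytest's Python entry point (equivalent to passing the same node IDs on the pytest command line):

```python
import pytest

pytest.main([
    "tests/unit_tests/_transformers/test_auto_model.py::TestPatchLegacyFlashAttnFlag",
    "tests/unit_tests/models/qwen3_5_moe/test_cp_linear_attn.py::TestForwardNoCpV55CacheAPI",
    "-q",
])
```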