45 changes: 45 additions & 0 deletions nemo_automodel/_transformers/kernel_patches.py
@@ -136,6 +136,51 @@ def _patch_liger_kernel(model):
raise RuntimeError("Failed to patch model")


def _patch_legacy_flash_attn_flag():
"""Bridge the legacy ``_supports_flash_attn_2`` class flag to v5.5's
``_supports_flash_attn``.

transformers v5.5 renamed the FA2-support attribute from
``_supports_flash_attn_2`` to ``_supports_flash_attn`` and switched the
dispatch check at ``_flash_attn_can_dispatch`` to the new name only.
Remote-code models pinned against <=v5.3 (e.g. microsoft/Phi-4-multimodal-instruct
sets ``_supports_flash_attn_2 = True`` in its modeling file) are not aware
of the rename, so their FA2 support is invisible to v5.5 and
``attn_implementation="flash_attention_2"`` raises ``ValueError``.

Install a property on ``PreTrainedModel._supports_flash_attn`` that falls
back to the legacy flag when a subclass has not set the new one. Subclasses
that set ``_supports_flash_attn = True`` directly still shadow the property
via normal MRO lookup, so native models are unaffected.
"""
import transformers.modeling_utils as mu

base = mu.PreTrainedModel
if getattr(base, "_nemo_fa2_flag_bridged", False):
return

# Capture the base-class default (``False`` on v5.5) so the fallback
# preserves original behavior when no flag is set anywhere.
_base_default = base.__dict__.get("_supports_flash_attn", False)

def _supports_flash_attn_fget(self):
cls = type(self)
for klass in cls.__mro__:
# Stop at the base — the property lives here; anything below is
# just the captured default.
if klass is base:
break
d = klass.__dict__
if "_supports_flash_attn" in d:
return d["_supports_flash_attn"]
if d.get("_supports_flash_attn_2") is True:
return True
return _base_default

base._supports_flash_attn = property(_supports_flash_attn_fget)
base._nemo_fa2_flag_bridged = True # type: ignore[attr-defined]


def _get_next_fallback_attn(attn_implementation: str) -> str:
"""
Get the next attention implementation in the priority list, in reverse order.
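For illustration only, not part of the diff: once the bridge is installed, a remote-code class that only sets the legacy flag resolves the new attribute through the property. A minimal sketch (the class name is hypothetical; the behavior mirrors the unit tests below), assuming transformers>=5.5:

    import transformers.modeling_utils as mu
    from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag

    _patch_legacy_flash_attn_flag()

    class LegacyRemoteModel(mu.PreTrainedModel):  # hypothetical remote-code model
        _supports_flash_attn_2 = True  # pre-v5.5 flag, as in Phi-4-multimodal-instruct

    # __new__ skips __init__ (which would require a config); attribute lookup only needs an instance.
    instance = LegacyRemoteModel.__new__(LegacyRemoteModel)
    assert instance._supports_flash_attn is True  # bridged from the legacy flag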
4 changes: 4 additions & 0 deletions nemo_automodel/_transformers/utils.py
@@ -236,6 +236,10 @@ def _patched_post_init(self):
_patch_phi4mm_processor()
_patch_peft_prepare_inputs()

from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag

_patch_legacy_flash_attn_flag()


def _patch_phi4mm_processor():
"""Patch AutoProcessor.from_pretrained to fall back to the remote
19 changes: 11 additions & 8 deletions nemo_automodel/components/models/qwen3_5_moe/cp_linear_attn.py
@@ -97,21 +97,24 @@ def _forward_no_cp(
):
"""HF GatedDeltaNet forward with FSDP-safe fp32 gate computation.

- Copied from transformers==5.3.0 Qwen3_5GatedDeltaNet.forward
- with gate computation replaced by self._compute_gate(a).
+ Mirrors transformers==5.5 Qwen3_5GatedDeltaNet.forward (uses the
+ per-layer cache API: ``has_previous_state(layer_idx)``,
+ ``cache_params.layers[layer_idx].{conv,recurrent}_states``, and the
+ ``update_{conv,recurrent}_state`` methods) with the gate computation
+ replaced by ``self._compute_gate(a)``.
"""
from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states

hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
batch_size, seq_len, _ = hidden_states.shape

use_precomputed_states = (
- cache_params is not None and cache_params.has_previous_state and seq_len == 1 and cache_position is not None
+ cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
)

- if cache_params is not None:
- conv_state = cache_params.conv_states[self.layer_idx]
- recurrent_state = cache_params.recurrent_states[self.layer_idx]
+ if use_precomputed_states:
+ conv_state = cache_params.layers[self.layer_idx].conv_states
+ recurrent_state = cache_params.layers[self.layer_idx].recurrent_states

mixed_qkv = self.in_proj_qkv(hidden_states)
mixed_qkv = mixed_qkv.transpose(1, 2)
@@ -133,7 +136,7 @@ def _forward_no_cp(
else:
if cache_params is not None:
conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
- cache_params.conv_states[self.layer_idx] = conv_state
+ cache_params.update_conv_state(conv_state, self.layer_idx)
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
@@ -183,7 +186,7 @@ def _forward_no_cp(
)

if cache_params is not None:
- cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
+ cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx)

core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
z = z.reshape(-1, self.head_v_dim)
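For reference, not part of the diff: the hunks above track the cache-API change between the two transformers versions. Side by side, grounded in the changed lines above:

    # pre-v5.5: property check and direct indexed assignment
    if cache_params.has_previous_state:
        conv_state = cache_params.conv_states[self.layer_idx]
    cache_params.conv_states[self.layer_idx] = conv_state

    # v5.5: per-layer method check, layers container, and update methods
    if cache_params.has_previous_state(self.layer_idx):
        conv_state = cache_params.layers[self.layer_idx].conv_states
    cache_params.update_conv_state(conv_state, self.layer_idx)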
113 changes: 113 additions & 0 deletions tests/unit_tests/_transformers/test_auto_model.py
@@ -182,6 +182,119 @@ def test_get_next_fallback_attn_edge_cases(self):
assert _get_next_fallback_attn("0") == "eager"


class TestPatchLegacyFlashAttnFlag:
"""Bridge the legacy ``_supports_flash_attn_2`` flag to v5.5's ``_supports_flash_attn``.

transformers v5.5 renamed the FA2-support attribute and switched the dispatch
check to the new name only. Remote-code models pinned against <=v5.3 still set
the legacy flag; the patch installs a fallback property so they dispatch to FA2.
"""

def test_installs_property_on_base(self):
import transformers.modeling_utils as mu

from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag

_patch_legacy_flash_attn_flag()
assert isinstance(mu.PreTrainedModel.__dict__["_supports_flash_attn"], property)

def test_is_idempotent(self):
import transformers.modeling_utils as mu

from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag

_patch_legacy_flash_attn_flag()
prop1 = mu.PreTrainedModel.__dict__["_supports_flash_attn"]
_patch_legacy_flash_attn_flag()
prop2 = mu.PreTrainedModel.__dict__["_supports_flash_attn"]
assert prop1 is prop2

def test_legacy_flag_bridged_to_true(self):
"""Subclass with only ``_supports_flash_attn_2 = True`` resolves to True."""
import transformers.modeling_utils as mu

from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag

_patch_legacy_flash_attn_flag()

class _Legacy(mu.PreTrainedModel):
_supports_flash_attn_2 = True

assert _Legacy.__new__(_Legacy)._supports_flash_attn is True

def test_explicit_new_flag_true_wins(self):
"""Subclass that sets ``_supports_flash_attn = True`` directly shadows the property."""
import transformers.modeling_utils as mu

from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag

_patch_legacy_flash_attn_flag()

class _Native(mu.PreTrainedModel):
_supports_flash_attn = True

assert _Native.__new__(_Native)._supports_flash_attn is True

def test_explicit_new_flag_false_wins_over_legacy_true(self):
"""Explicit ``_supports_flash_attn = False`` shadows a legacy True."""
import transformers.modeling_utils as mu

from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag

_patch_legacy_flash_attn_flag()

class _Native(mu.PreTrainedModel):
_supports_flash_attn = False
_supports_flash_attn_2 = True

assert _Native.__new__(_Native)._supports_flash_attn is False

def test_neither_flag_falls_back_to_base_default(self):
"""Subclass with neither flag falls back to the captured base default (False)."""
import transformers.modeling_utils as mu

from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag

_patch_legacy_flash_attn_flag()

class _Bare(mu.PreTrainedModel):
pass

assert _Bare.__new__(_Bare)._supports_flash_attn is False

def test_legacy_flag_false_does_not_bridge(self):
"""Only ``_supports_flash_attn_2 is True`` bridges; False passes through."""
import transformers.modeling_utils as mu

from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag

_patch_legacy_flash_attn_flag()

class _LegacyFalse(mu.PreTrainedModel):
_supports_flash_attn_2 = False

assert _LegacyFalse.__new__(_LegacyFalse)._supports_flash_attn is False

def test_nearest_subclass_wins_in_mro(self):
"""In multi-level inheritance, the nearest ``_supports_flash_attn`` in MRO wins."""
import transformers.modeling_utils as mu

from nemo_automodel._transformers.kernel_patches import _patch_legacy_flash_attn_flag

_patch_legacy_flash_attn_flag()

class _Ancestor(mu.PreTrainedModel):
_supports_flash_attn_2 = True

class _Mid(_Ancestor):
_supports_flash_attn = False

class _Leaf(_Mid):
pass

assert _Leaf.__new__(_Leaf)._supports_flash_attn is False


class DummyModel(torch.nn.Module):
"""A tiny nn.Module that behaves enough like a HF/BERT style model."""

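A note on the instantiation pattern in the TestPatchLegacyFlashAttnFlag tests above: ``SubClass.__new__(SubClass)`` creates an instance without running ``PreTrainedModel.__init__`` (which requires a config); attribute lookup only needs an instance. The shadowing the patch relies on is ordinary Python MRO behavior, shown here with illustrative names:

    class Base:
        @property
        def flag(self):  # stands in for the bridged property
            return False

    class Child(Base):
        flag = True  # a plain class attribute earlier in the MRO shadows the property

    assert Child().flag is True   # the subclass value wins
    assert Base().flag is False   # the property still answers for the base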
62 changes: 62 additions & 0 deletions tests/unit_tests/models/qwen3_5_moe/test_cp_linear_attn.py
@@ -302,6 +302,68 @@ def test_cp_mesh_gt_1_calls_forward_with_cp(self, module, device):
mock_cp_fwd.assert_called_once()


class TestForwardNoCpV55CacheAPI:
"""_forward_no_cp must use the transformers v5.5 per-layer cache API.

v5.5 changed ``has_previous_state`` from a property to a method taking ``layer_idx``, moved
states under ``cache.layers[layer_idx]``, and exposed ``update_conv_state`` /
``update_recurrent_state`` methods in place of direct indexed assignment. A plain
``DynamicCache`` (no pre-existing state, as in training) has no top-level
``conv_states`` attribute — the pre-v5.5 pattern raised ``AttributeError``.
"""

def test_training_cache_no_previous_state_runs(self, module, text_config, device):
"""Training-style forward with a fresh DynamicCache (no previous state) must not raise."""
from transformers import DynamicCache

B, S, D = 1, 8, module.hidden_size
hidden = torch.randn(B, S, D, device=device)
out = module._forward_no_cp(hidden, cache_params=DynamicCache(config=text_config))
assert out.shape == (B, S, D)

def test_no_cache_path_still_works(self, module, device):
"""When cache_params is None, _forward_no_cp runs the pure compute path."""
B, S, D = 1, 8, module.hidden_size
hidden = torch.randn(B, S, D, device=device)
out = module._forward_no_cp(hidden, cache_params=None)
assert out.shape == (B, S, D)

def test_updates_conv_state_via_method(self, module, text_config, device):
"""Prefill writes the conv state via ``update_conv_state(state, layer_idx)``."""
from transformers import DynamicCache

B, S, D = 1, 8, module.hidden_size
hidden = torch.randn(B, S, D, device=device)
cache = DynamicCache(config=text_config)
with (
patch.object(cache, "update_conv_state", wraps=cache.update_conv_state) as mock_update_conv,
patch.object(cache, "update_recurrent_state", wraps=cache.update_recurrent_state) as mock_update_rec,
[Contributor comment on lines +338 to +340]

Nit: this assertion is correct but hard to follow. A simpler equivalent:

Suggested change:

    assert mock_update_conv.call_args.args[1] == module.layer_idx

The call site always passes layer_idx as the second positional arg, so indexing args[1] directly is sufficient and easier to read.

):
module._forward_no_cp(hidden, cache_params=cache)
mock_update_conv.assert_called_once()
# Written at the layer_idx owned by the module.
args, _ = mock_update_conv.call_args
assert args[1] == module.layer_idx
mock_update_rec.assert_called_once()

def test_has_previous_state_called_as_method_with_layer_idx(self, module, text_config, device):
"""v5.5 ``has_previous_state`` is a method that takes ``layer_idx``."""
from transformers import DynamicCache

B, S, D = 1, 8, module.hidden_size
hidden = torch.randn(B, S, D, device=device)
cache = DynamicCache(config=text_config)
with patch.object(cache, "has_previous_state", wraps=cache.has_previous_state) as mock_hps:
module._forward_no_cp(hidden, cache_params=cache)
mock_hps.assert_called()
# At least one call must pass the module's layer_idx.
layer_idx_seen = any(
(call.args and call.args[0] == module.layer_idx) or call.kwargs.get("layer_idx") == module.layer_idx
for call in mock_hps.call_args_list
)
assert layer_idx_seen


# ============================================================================
# _conv1d_with_cp
# ============================================================================