
[Qwen3.5] Fix GDN linear attention multi-token cached forward#45513

Merged
vasqu merged 9 commits into huggingface:main from kashif:fix-qwen35-linear-attn-multi-token-cached
Apr 27, 2026

Conversation

@kashif
Contributor

@kashif kashif commented Apr 19, 2026

What does this PR do?

The gated-delta-net (GDN) forward only used the cached recurrent state when seq_len == 1. For any multi-token forward with a populated cache (e.g. chunked prefill continuation or speculative-decoding verification), it fell through to chunk_gated_delta_rule(initial_state=None), silently restarting the linear-attention state from zero and ignoring the prefill state.

This breaks the causal-LM invariant that the logits at position i must not depend on whether later tokens are batched into the same call — position 0 of a 16-token verify forward ended up differing from the corresponding single-token cached decode, collapsing to high-frequency context tokens and destroying speculative-decoding correctness.

Add a use_cached_chunk path that, when has_previous_state is true and seq_len > 1:

  • reads the cached conv_state / recurrent_state,
  • prepends the conv context onto the chunk input so the causal conv sees the correct left-context,
  • drops the prepended context from the output,
  • passes the cached recurrent_state as initial_state to chunk_gated_delta_rule.

The same fix propagates to qwen3_5_moe via the modular system and is applied to qwen3_next and olmo_hybrid.
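
For illustration, here is a minimal sketch of that branch. The helper names, shapes, and the `chunk_kernel` stand-in are assumptions made for readability; the actual code in the modeling files differs in detail.

```python
import torch
import torch.nn.functional as F

def cached_chunk_branch(mixed_qkv, conv_weight, conv_state, recurrent_state, chunk_kernel):
    """Sketch of the use_cached_chunk path (hypothetical names/shapes).

    mixed_qkv:       (batch, channels, seq_len) projections for the new tokens
    conv_weight:     (channels, 1, kernel_size) depthwise causal-conv weights
    conv_state:      (batch, channels, kernel_size - 1) cached conv left-context
    recurrent_state: cached linear-attention state from the prefill
    chunk_kernel:    stand-in for chunk_gated_delta_rule(hidden, initial_state=...)
    """
    seq_len = mixed_qkv.shape[-1]
    kernel_size = conv_weight.shape[-1]

    # 1) Prepend the cached conv context so the causal conv sees the real left-context
    #    instead of implicit zero-padding.
    x = torch.cat([conv_state, mixed_qkv], dim=-1)
    hidden = F.conv1d(F.pad(x, (kernel_size - 1, 0)), conv_weight, groups=x.shape[1])

    # 2) Drop the outputs that correspond to the prepended context.
    hidden = hidden[..., -seq_len:]

    # 3) Seed the chunk kernel with the cached recurrent state instead of None.
    return chunk_kernel(hidden, initial_state=recurrent_state)
```

The old path effectively ran the same conv and kernel with `conv_state` replaced by zeros and `initial_state=None`, which is why the first positions of a cached multi-token forward diverged.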

Add a unit test that compares the first-position output of a multi-token cached forward against the single-token cached forward on the same token and cache. Without this fix the mismatch is 100%.

Fixes # (issue)

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

The gated-delta-net forward only used the cached recurrent state when
`seq_len == 1`. For any multi-token forward with a populated cache (e.g.
chunked prefill continuation or speculative-decoding verification), it
fell through to `chunk_gated_delta_rule(initial_state=None)`, silently
restarting the linear layers from zero and ignoring the prefill state.

This breaks the causal-LM invariant that the logits at position `i` must
not depend on whether later tokens are batched into the same call —
position 0 of a 16-token verify forward ended up differing from the
corresponding single-token cached decode, collapsing to high-frequency
context tokens and destroying speculative-decoding correctness.

Add a `use_cached_chunk` path that, when `has_previous_state` is true
and `seq_len > 1`:
- reads the cached `conv_state` / `recurrent_state`,
- prepends the conv context onto the chunk input so the causal conv sees
  the correct left-context,
- drops the prepended context from the output,
- passes the cached `recurrent_state` as `initial_state` to
  `chunk_gated_delta_rule`.

The same fix propagates to `qwen3_5_moe` via the modular system.

Add a unit test that compares the first-position output of a
multi-token cached forward against the single-token cached forward on
the same token and cache. Without this fix the mismatch is 100%.
@kashif kashif requested a review from Cyrilvallez April 19, 2026 09:46
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez
Member

cc @vasqu here to double-check as well: I had raised similar concerns when I was refactoring the mamba caches, but I'm less familiar than you with all the mamba kernels and how exactly the updates should look. But some mambas use the seq_len == 1 part and others don't, and as I said before idk if it was really intended or not.

Fix looks legit though, but I would appreciate a second opinion!

@kashif
Contributor Author

kashif commented Apr 23, 2026

Thanks for the thorough check! The seq_len == 1 pattern isn't intentional — I believe it was copied forward from mamba → qwen3_next → qwen3_5 without anyone exercising the "populated cache + seq_len > 1" path (plain .generate() always has seq_len == 1 in decode, and training doesn't use the cache). I grepped transformers and the same has_previous_state(...) and seq_len == 1 guard appears in 9 mamba-family files: bamba, falcon_h1, granitemoehybrid, jamba, olmo_hybrid, qwen3_next, qwen3_5, qwen3_5_moe, zamba. Every one of them drops the cached recurrent state in the chunk path (passes initial_state=None) and zero-pads the conv state instead of prepending the cached context.

The canonical reference is flash-linear-attention. I grepped fla/layers/ and every linear-attention implementation (gated_deltanet, delta_net, gla, simple_gla, hgrn, hgrn2, rwkv6, rwkv7, lightnet, gated_deltaproduct — 10 I checked) follows the same canonical pattern:

recurrent_state = last_state['recurrent_state'] if last_state is not None else None
mode = 'fused_recurrent' if q_len <= 64 else self.mode   # perf heuristic, not correctness
o, recurrent_state = chunk_or_fused_recurrent_kernel(..., initial_state=recurrent_state, ...)

No seq_len gating on correctness. The q_len <= 64 cutoff in FLA is a kernel-selection heuristic (fused-recurrent is faster for short seqs), not a correctness gate — both kernels accept and propagate initial_state.

This fix brings qwen3_5 in line with FLA. I've constrained the change to qwen3_5 + qwen3_5_moe (via modular) so it's reviewable; happy to open follow-up PRs for the other mamba-family files if you'd like them audited together.

The new unit test (test_linear_attention_multi_token_cached_forward_matches_single_token) checks the minimal invariant the old code violated — a causal LM's logits at position i must not depend on whether later tokens are batched into the same call, across any combination of seq_len and cache state. 100% element mismatch on main, passes after the fix.
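
To make the invariant concrete, here is a self-contained toy recurrence (not the actual model or test code) showing what the old guard broke and what the fix restores:

```python
import torch

def chunk_forward(x, initial_state=None, decay=0.9):
    # Toy linear recurrence standing in for the chunk kernel; initial_state plays
    # the role of the cached recurrent_state.
    state = torch.zeros_like(x[:, 0]) if initial_state is None else initial_state
    outs = []
    for t in range(x.shape[1]):
        state = decay * state + x[:, t]
        outs.append(state)
    return torch.stack(outs, dim=1), state

torch.manual_seed(0)
prefill = torch.randn(2, 8, 4)   # (batch, seq, dim) prompt
verify = torch.randn(2, 16, 4)   # multi-token continuation (e.g. speculative verify)

_, cached_state = chunk_forward(prefill)                             # prefill populates the cache

ref, _ = chunk_forward(verify[:, :1], initial_state=cached_state)    # single-token cached decode
buggy, _ = chunk_forward(verify, initial_state=None)                 # old: cache silently dropped
fixed, _ = chunk_forward(verify, initial_state=cached_state)         # new: cache threaded through

print(torch.allclose(buggy[:, 0], ref[:, 0]))   # False - position 0 depends on batching
print(torch.allclose(fixed[:, 0], ref[:, 0]))   # True  - invariant holds
```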

Contributor

@vasqu vasqu left a comment

Happy to have that, it is a valid use case and just a limitation from our side tbh (which you lift here). I feel like something similar could be done for Mamba2 as well, but maybe out of scope for here.

My main comment is more about conventions, and maybe propagating to qwen3_next and olmo_hybrid if possible.

# here, so we fall into the chunked kernel and must carry forward both the cached conv
# context and the cached recurrent state; otherwise the first-position outputs silently
# diverge from the equivalent single-token decode.
use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len > 1
Contributor

Why not remove the seq_len > 1 then? I agree that this shouldn't be a restriction, as it will be used as the initial state and the forward essentially stays the same. Same/similar for qwen3_next then.

We have something similar to Mamba2, unsure about Mamba1. Would be a nice-to-have there as well, but not directly needed in this PR. Reminds me of these lines that were removed:

if cache_params is not None and cache_params.has_previous_state:
    previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device)
else:
    previous_states = torch.zeros_like(states[:, :1])

Comment on lines +253 to +256
if use_cached_chunk:
    # Prepend the cached conv context so the causal conv sees the correct left-context
    # when continuing a prior forward rather than starting from zero-padding.
    mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1)
Contributor

Ok I see now, we have to accommodate the seq_len > 1 case. I'd still keep the term precomputed_state and something along the lines of single_token vs chunk_tokens to have clearer boundaries.

@kashif
Contributor Author

kashif commented Apr 23, 2026

thanks @vasqu let me simplify and update

@kashif
Contributor Author

kashif commented Apr 23, 2026

Should I do the qwen3_next and olmo_hybrid fixes in a follow-up PR or this one, @vasqu?

use_precomputed_states = (
    cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
)
use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx)
Contributor

Suggested change
use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx)

Ah sorry, I meant it a bit differently: imo both are precomputed states that differ in their number of tokens, so I'd like to split the logic explicitly, focusing on seq_len where needed. So:

  1. Only use precomputed states (and remove the seq len condition)
  2. Where needed check for the initial seq len to determine if we have a single decode step or chunks

@vasqu
Contributor

vasqu commented Apr 23, 2026

@kashif would be nice if we could do those models here as well - they are all based on GDN

@kashif
Contributor Author

kashif commented Apr 23, 2026

yup! fixing all GDN ones

Review feedback: unify cached-forward state flag, gate single-token/chunk locally

Replace the two `use_precomputed_states` / `use_cached_chunk` variables with a single
`use_precomputed_states = cache_params is not None and cache_params.has_previous_state(...)`
that just signals "we have cached conv/recurrent state to continue from". The split between
the single-token (fused per-step) and chunk-tokens (chunk kernel + cached conv context) modes
is now expressed locally via `seq_len == 1` checks at the three places where it actually
matters — kernel dispatch, conv-context prepend, and prepend-drop slice — as requested in
review.

Behavior is unchanged; this is pure restructuring for clarity. `modeling_qwen3_5.py` and
`modeling_qwen3_5_moe.py` regenerated via `check_modular_conversion.py --fix_and_overwrite`.
@kashif kashif force-pushed the fix-qwen35-linear-attn-multi-token-cached branch from 0f14464 to e395f5a April 23, 2026 15:45
Propagate linear-attention multi-token cached-forward fix to qwen3_next and olmo_hybrid

qwen3_next's GatedDeltaNet had the same bug as qwen3_5: for a multi-token forward after the
cache was populated (chunked-prefill continuation or speculative-decoding verification), the
chunk kernel was called with `initial_state=None` and the conv state was zero-padded, silently
dropping the cached state. Applies the same pattern: unified `use_precomputed_states` flag
(no seq_len condition), with single-token vs chunk-tokens routing gated locally on
`seq_len == 1`.

olmo_hybrid has the same kind of bug in its custom `OlmoHybridShortConvolution` — when the
caller had cached state but fed more than one token (`use_precomputed=False` under the old
`seq_len==1` gate), the conv was zero-padded instead of using the cached context. Fix the
caller to drop the `seq_len==1` gate from `use_precomputed`, make the ShortConvolution branch
on `seq_len==1` locally, and route recurrent vs chunk kernel dispatch in the caller likewise.
olmo_hybrid's chunk kernel path already passed `initial_state=recurrent_state` correctly; no
change needed there beyond the `use_precomputed` flag semantics.

Add the same causal-LM invariance test to both suites
(`test_linear_attention_multi_token_cached_forward_matches_single_token`). 100% element
mismatch on the old code, passes after the fix. RUN_SLOW integration suites for both models
pass end-to-end.

`modeling_qwen3_next.py` and `modeling_olmo_hybrid.py` regenerated via
`check_modular_conversion.py --fix_and_overwrite`.
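
As a self-contained illustration of the short-convolution part of this (toy shapes, not olmo_hybrid's actual OlmoHybridShortConvolution), prepending the cached left-context reproduces the full-sequence conv exactly, while zero-padding does not:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
channels, kernel_size = 3, 4
weight = torch.randn(channels, 1, kernel_size)           # depthwise causal-conv weights

def causal_conv(x):
    # Standard causal short convolution: left-pad with kernel_size - 1 zeros.
    return F.conv1d(F.pad(x, (kernel_size - 1, 0)), weight, groups=channels)

full = torch.randn(1, channels, 24)
prefill, chunk = full[..., :8], full[..., 8:]

reference = causal_conv(full)[..., 8:]                    # ground truth for the chunk positions
conv_state = prefill[..., -(kernel_size - 1):]            # cached left-context after the prefill

zero_padded = causal_conv(chunk)                          # old behaviour: zero-padded conv
with_context = causal_conv(torch.cat([conv_state, chunk], dim=-1))[..., -chunk.shape[-1]:]  # fixed

print(torch.allclose(zero_padded, reference))    # False - first kernel_size - 1 positions wrong
print(torch.allclose(with_context, reference))   # True  - matches the uncached full forward
```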
@kashif kashif changed the title [Qwen3.5] Fix Qwen3.5 linear attention multi-token cached forward [Qwen3.5] Fix GDN linear attention multi-token cached forward Apr 23, 2026
@kashif
Contributor Author

kashif commented Apr 27, 2026

@vasqu if you can have another look whenever you are free

Contributor

@vasqu vasqu left a comment

LGTM, just some last nits left, then we are good to merge 🤗

Comment thread src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py
Comment thread src/transformers/models/qwen3_5/modular_qwen3_5.py Outdated
Review feedback: keep "prefill mode / multi-token decode" comment, label prepend block

Restore the historical `# Multi-token forward (prefill mode)` comment on the chunk-mode
else-branch in olmo_hybrid (and the equivalent qwen3_5 / qwen3_next paths) and adjust the
wording so the two intents — fresh prefill vs. cached chunked-tokens decode — are visible
on the same line. Tag the conv-context prepend block with a "dropped at the end of this
branch" hint so a reader knows the mirror operation exists.

Comment-only; behavior unchanged. Generated files re-emitted via
`check_modular_conversion.py --fix_and_overwrite`.
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: olmo_hybrid, qwen3_5, qwen3_5_moe, qwen3_next

@vasqu
Contributor

vasqu commented Apr 27, 2026

run-slow: olmo_hybrid, qwen3_5, qwen3_5_moe, qwen3_next

@github-actions
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/olmo_hybrid", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_next"]
quantizations: []

@vasqu
Contributor

vasqu commented Apr 27, 2026

Just sanity checking with run-slow, but should be good; will merge afterwards.

@github-actions
Contributor

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit | Description |
| --- | --- | --- |
| RUN | 37b7ffb3 | workflow commit (merge commit) |
| PR | 40f471b9 | branch commit (from PR) |
| main | e651c68e | base commit (on main) |

✅ No failing test specific to this PR 🎉 👏 !

@vasqu vasqu added this pull request to the merge queue Apr 27, 2026
@vasqu
Contributor

vasqu commented Apr 27, 2026

@kashif thanks for the PR 🤗 merging now!

Merged via the queue into huggingface:main with commit f53ca05 Apr 27, 2026
22 checks passed
ArthurZucker pushed a commit that referenced this pull request Apr 28, 2026