[Qwen3.5] Fix GDN linear attention multi-token cached forward #45513
vasqu merged 9 commits into huggingface:main
Conversation
The gated-delta-net forward only used the cached recurrent state when `seq_len == 1`. For any multi-token forward with a populated cache (e.g. chunked prefill continuation or speculative-decoding verification), it fell through to `chunk_gated_delta_rule(initial_state=None)`, silently restarting the linear layers from zero and ignoring the prefill state. This breaks the causal-LM invariant that the logits at position `i` must not depend on whether later tokens are batched into the same call — position 0 of a 16-token verify forward ended up differing from the corresponding single-token cached decode, collapsing to high-frequency context tokens and destroying speculative-decoding correctness.

Add a `use_cached_chunk` path that, when `has_previous_state` is true and `seq_len > 1`:

- reads the cached `conv_state` / `recurrent_state`,
- prepends the conv context onto the chunk input so the causal conv sees the correct left-context,
- drops the prepended context from the output,
- passes the cached `recurrent_state` as `initial_state` to `chunk_gated_delta_rule`.

The same fix propagates to `qwen3_5_moe` via the modular system. Add a unit test that compares the first-position output of a multi-token cached forward against the single-token cached forward on the same token and cache. Without this fix the mismatch is 100%.
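To make the conv-context half of the fix concrete, here is a minimal, self-contained sketch of the prepend-and-drop pattern the bullets describe. It is illustrative only: `causal_conv_from_cache` is a hypothetical helper with made-up shapes, not the code in `modeling_qwen3_5.py`, which applies the same carry to the model's `mixed_qkv` projections.

```python
import torch
import torch.nn.functional as F

def causal_conv_from_cache(x_chunk, conv_state, weight):
    """x_chunk: (B, C, T) new tokens; conv_state: (B, C, K-1) cached left-context;
    weight: (C, 1, K) depthwise kernel. Returns (y, new_conv_state)."""
    k = weight.shape[-1]
    # Prepend the cached context so the conv sees the true left-context instead
    # of zero padding; with no extra padding the output length is exactly T,
    # aligned with x_chunk's positions (the "drop" falls out of the slicing).
    x = torch.cat([conv_state, x_chunk], dim=-1)
    y = F.conv1d(x, weight, groups=x.shape[1])
    new_conv_state = x[..., -(k - 1):]  # last K-1 inputs become the next cache
    return y, new_conv_state

B, C, K, T = 1, 4, 4, 6
w = torch.randn(C, 1, K)
x = torch.randn(B, C, T)
state = torch.zeros(B, C, K - 1)  # fresh prefill: zero left-context

y_chunk, _ = causal_conv_from_cache(x, state, w)  # all 6 tokens in one call

s, ys = state, []
for t in range(T):  # same tokens, one at a time
    y_t, s = causal_conv_from_cache(x[..., t : t + 1], s, w)
    ys.append(y_t)

# The invariant the PR restores: batching tokens into one call must not
# change any position's output.
assert torch.allclose(y_chunk, torch.cat(ys, dim=-1), atol=1e-6)
```

The `torch.cat([conv_state, mixed_qkv], dim=-1)` line in the diff excerpts further down is this same carry applied to the model's real tensors; the recurrent-state half of the fix is just passing `initial_state=recurrent_state` to the chunk kernel.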
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @vasqu here to double-check as well: I had raised similar concerns when I was refactoring the mamba caches, but I'm less familiar than you with all the mamba kernels and how exactly the updates should look. But some mambas use the … Fix looks legit though, but would appreciate a second opinion!
Thanks for the thorough check! The canonical reference is the FLA implementation:

```python
recurrent_state = last_state['recurrent_state'] if last_state is not None else None
mode = 'fused_recurrent' if q_len <= 64 else self.mode  # perf heuristic, not correctness
o, recurrent_state = chunk_or_fused_recurrent_kernel(..., initial_state=recurrent_state, ...)
```

No seq_len gating on correctness. This fix brings qwen3_5 in line with FLA. I've constrained the change to qwen3_5 + qwen3_5_moe (via modular) so it's reviewable; happy to open follow-up PRs for the other mamba-family files if you'd like them audited together. The new unit test (…)
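For intuition on why `initial_state` is a pure correctness knob, here is a toy linear-attention recurrence (plain and unfused, not the FLA kernel or the gated delta rule) showing that splitting a sequence and carrying the state reproduces the unsplit result, while `initial_state=None` restarts from zero:

```python
import torch

def linear_attn(q, k, v, initial_state=None):
    """Toy linear-attention recurrence. q, k: (T, D); v: (T, Dv).
    Returns per-position outputs (T, Dv) and the final state (D, Dv)."""
    state = torch.zeros(k.shape[1], v.shape[1]) if initial_state is None else initial_state
    outs = []
    for t in range(q.shape[0]):
        state = state + torch.outer(k[t], v[t])  # accumulate the KV state
        outs.append(q[t] @ state)                # read out with the query
    return torch.stack(outs), state

T, split, D, Dv = 8, 5, 4, 4
q, k, v = torch.randn(T, D), torch.randn(T, D), torch.randn(T, Dv)

o_full, _ = linear_attn(q, k, v)
o1, s = linear_attn(q[:split], k[:split], v[:split])
o2, _ = linear_attn(q[split:], k[split:], v[split:], initial_state=s)  # carry state
assert torch.allclose(o_full, torch.cat([o1, o2]), atol=1e-5)

# With initial_state=None on the continuation, every position after the split
# is computed as if the prefix never happened (the bug this PR fixes).
o2_bad, _ = linear_attn(q[split:], k[split:], v[split:])
assert not torch.allclose(o_full[split:], o2_bad, atol=1e-5)
```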
vasqu left a comment
Happy to have that, it is a valid use case and just a limitation from our side tbh (which you lift here). I feel like something similar could be done for Mamba2 as well but maybe out of scope for here.

My main comment is more about conventions and maybe propagating to qwen3 next and olmo hybrid if possible.
```python
# here, so we fall into the chunked kernel and must carry forward both the cached conv
# context and the cached recurrent state; otherwise the first-position outputs silently
# diverge from the equivalent single-token decode.
use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len > 1
```
Why not remove the `seq_len > 1` then? I agree that this shouldn't be a restriction as it will be used as initial state and the forward essentially stays the same. Same/similar for qwen3 next then.
We have something similar in Mamba2, unsure about Mamba1. Would be nice to have there as well but not directly needed in this PR. Reminds me of these lines that were removed:

transformers/src/transformers/models/mamba2/modeling_mamba2.py, lines 642 to 645 in 83a6c5b
```python
if use_cached_chunk:
    # Prepend the cached conv context so the causal conv sees the correct left-context
    # when continuing a prior forward rather than starting from zero-padding.
    mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1)
```
Ok I see now, we have to accommodate the `seq_len > 1` case. I'd still keep the terms `precomputed_state` and something along the lines of `single_token` vs `chunk_tokens` to have clearer boundaries.
Thanks @vasqu, let me simplify and update.
Should I do the qwen3_next and olmo_hybrid fixes in a follow-up PR or in this one, @vasqu?
```python
use_precomputed_states = (
    cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
)
use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx)
```
Ah sorry I meant it a bit differently, imo both are precomputed states that differ in their number of tokens, so I'd like to split the logic explicitly, focusing on seq_len where needed. So:

- Only use precomputed states (and remove the seq len condition)
- Where needed, check the initial seq len to determine if we have a single decode step or chunks (see the sketch after this list)
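A minimal sketch of that routing, assuming the flag and kernel names from the diff excerpts above (the function and its return strings are purely illustrative, not the final code):

```python
def dispatch(cache_has_state: bool, seq_len: int) -> str:
    """Illustrative routing: one 'precomputed state' flag with no seq_len gating,
    and the single-token vs chunk split decided locally where it matters."""
    use_precomputed_states = cache_has_state  # correctness flag only
    if use_precomputed_states and seq_len == 1:
        return "fused per-step kernel on cached conv/recurrent state"
    if use_precomputed_states:
        return "chunk kernel + cached conv context + initial_state=recurrent_state"
    return "chunk kernel, fresh prefill (zero padding, initial_state=None)"

assert dispatch(False, 16).startswith("chunk kernel, fresh")
assert dispatch(True, 1).startswith("fused per-step")
assert dispatch(True, 16).startswith("chunk kernel + cached")
```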
@kashif would be nice if we could do those models here as well - they are all based on GDN

yup! fixing all GDN ones
Review feedback: unify cached-forward state flag, gate single-token/chunk locally

Replace the two `use_precomputed_states` / `use_cached_chunk` variables with a single `use_precomputed_states = cache_params is not None and cache_params.has_previous_state(...)` that just signals "we have cached conv/recurrent state to continue from". The split between the single-token (fused per-step) and chunk-tokens (chunk kernel + cached conv context) modes is now expressed locally via `seq_len == 1` checks at the three places where it actually matters — kernel dispatch, conv-context prepend, and prepend-drop slice — as requested in review. Behavior is unchanged; this is pure restructuring for clarity. `modeling_qwen3_5.py` and `modeling_qwen3_5_moe.py` regenerated via `check_modular_conversion.py --fix_and_overwrite`.
Force-pushed from 0f14464 to e395f5a
Propagate linear-attention multi-token cached-forward fix to qwen3_next and olmo_hybrid

qwen3_next's GatedDeltaNet had the same bug as qwen3_5: for a multi-token forward after the cache was populated (chunked-prefill continuation or speculative-decoding verification), the chunk kernel was called with `initial_state=None` and the conv state was zero-padded, silently dropping the cached state. Applies the same pattern: unified `use_precomputed_states` flag (no seq_len condition), with single-token vs chunk-tokens routing gated locally on `seq_len == 1`.

olmo_hybrid has the same kind of bug in its custom `OlmoHybridShortConvolution` — when the caller had cached state but fed more than one token (`use_precomputed=False` under the old `seq_len==1` gate), the conv was zero-padded instead of using the cached context. Fix the caller to drop the `seq_len==1` gate from `use_precomputed`, make the ShortConvolution branch on `seq_len==1` locally, and route recurrent vs chunk kernel dispatch in the caller likewise. olmo_hybrid's chunk kernel path already passed `initial_state=recurrent_state` correctly; no change needed there beyond the `use_precomputed` flag semantics.

Add the same causal-LM invariance test to both suites (`test_linear_attention_multi_token_cached_forward_matches_single_token`). 100% element mismatch on the old code, passes after the fix. RUN_SLOW integration suites for both models pass end-to-end. `modeling_qwen3_next.py` and `modeling_olmo_hybrid.py` regenerated via `check_modular_conversion.py --fix_and_overwrite`.
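A hedged sketch of the invariance the new test asserts. The helper below, its kwargs, and its tolerances are assumptions for illustration, not the committed test code, which lives in the model test suites:

```python
import torch

def check_first_position_invariance(model, prefix_ids, next_ids):
    """Logits at next_ids[:, 0] must not depend on whether the remaining
    tokens of next_ids are batched into the same forward call."""
    # single-token cached decode
    cache = model(prefix_ids, use_cache=True).past_key_values
    single = model(next_ids[:, :1], past_key_values=cache, use_cache=True).logits[:, 0]
    # multi-token cached forward (e.g. speculative-decoding verification);
    # re-run the prefill so the cache is not already mutated by the call above
    cache = model(prefix_ids, use_cache=True).past_key_values
    multi = model(next_ids, past_key_values=cache, use_cache=True).logits[:, 0]
    torch.testing.assert_close(single, multi, atol=1e-4, rtol=1e-4)
```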
@vasqu if you can have another look whenever you are free
vasqu left a comment
LGTM, just some last nits left, then we are good to merge 🤗
Review feedback: keep "prefill mode / multi-token decode" comment, label prepend block

Restore the historical `# Multi-token forward (prefill mode)` comment on the chunk-mode else-branch in olmo_hybrid (and the equivalent qwen3_5 / qwen3_next paths) and adjust the wording so the two intents — fresh prefill vs. cached chunked-tokens decode — are visible on the same line. Tag the conv-context prepend block with a "dropped at the end of this branch" hint so a reader knows the mirror operation exists. Comment-only; behavior unchanged. Generated files re-emitted via `check_modular_conversion.py --fix_and_overwrite`.
[For maintainers] Suggested jobs to run (before merge): run-slow: olmo_hybrid, qwen3_5, qwen3_5_moe, qwen3_next
run-slow: olmo_hybrid, qwen3_5, qwen3_5_moe, qwen3_next
This comment contains models: ["models/olmo_hybrid", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_next"]
Just sanity checking with run-slow, but should be good, will merge afterwards.
@kashif thanks for the PR 🤗 merging now!
What does this PR do?

The gated-delta-net (GDN) forward only used the cached recurrent state when `seq_len == 1`. For any multi-token forward with a populated cache (e.g. chunked prefill continuation or speculative-decoding verification), it fell through to `chunk_gated_delta_rule(initial_state=None)`, silently restarting the linear layers from zero and ignoring the prefill state.

This breaks the causal-LM invariant that the logits at position `i` must not depend on whether later tokens are batched into the same call — position 0 of a 16-token verify forward ended up differing from the corresponding single-token cached decode, collapsing to high-frequency context tokens and destroying speculative-decoding correctness.

Add a `use_cached_chunk` path that, when `has_previous_state` is true and `seq_len > 1`:

- reads the cached `conv_state` / `recurrent_state`,
- prepends the conv context onto the chunk input so the causal conv sees the correct left-context,
- drops the prepended context from the output,
- passes the cached `recurrent_state` as `initial_state` to `chunk_gated_delta_rule`.

The same fix propagates to `qwen3_5_moe` via the modular system and is applied to `qwen3_next` and `olmo_hybrid`.

Add a unit test that compares the first-position output of a multi-token cached forward against the single-token cached forward on the same token and cache. Without this fix the mismatch is 100%.
Fixes # (issue)
Code Agent Policy
The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.
PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.
This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read
CONTRIBUTING.md.

Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.