From 49a614caf8b48467b39ff7ae99a93603d1a4a772 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 19 Apr 2026 09:34:04 +0000 Subject: [PATCH 1/4] Fix Qwen3.5 linear attention multi-token cached forward MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The gated-delta-net forward only used the cached recurrent state when `seq_len == 1`. For any multi-token forward with a populated cache (e.g. chunked prefill continuation or speculative-decoding verification), it fell through to `chunk_gated_delta_rule(initial_state=None)`, silently restarting the linear layers from zero and ignoring the prefill state. This breaks the causal-LM invariant that the logits at position `i` must not depend on whether later tokens are batched into the same call — position 0 of a 16-token verify forward ended up differing from the corresponding single-token cached decode, collapsing to high-frequency context tokens and destroying speculative-decoding correctness. Add a `use_cached_chunk` path that, when `has_previous_state` is true and `seq_len > 1`: - reads the cached `conv_state` / `recurrent_state`, - prepends the conv context onto the chunk input so the causal conv sees the correct left-context, - drops the prepended context from the output, - passes the cached `recurrent_state` as `initial_state` to `chunk_gated_delta_rule`. The same fix propagates to `qwen3_5_moe` via the modular system. Add a unit test that compares the first-position output of a multi-token cached forward against the single-token cached forward on the same token and cache. Without this fix the mismatch is 100%. --- .../models/qwen3_5/modeling_qwen3_5.py | 23 ++++++++--- .../models/qwen3_5/modular_qwen3_5.py | 23 ++++++++--- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 23 ++++++++--- tests/models/qwen3_5/test_modeling_qwen3_5.py | 41 +++++++++++++++++++ 4 files changed, 95 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 2c4eba9597dc..3c68fe17a270 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -433,9 +433,15 @@ def forward( use_precomputed_states = ( cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) + # Multi-token forward after a prior forward has populated the cache (e.g. chunked prefill + # continuation or speculative verification). The single-token fused kernel can't be used + # here, so we fall into the chunked kernel and must carry forward both the cached conv + # context and the cached recurrent state; otherwise the first-position outputs silently + # diverge from the equivalent single-token decode. + use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len > 1 # getting projected states from cache if it exists - if use_precomputed_states: + if use_precomputed_states or use_cached_chunk: conv_state = cache_params.layers[self.layer_idx].conv_states recurrent_state = cache_params.layers[self.layer_idx].recurrent_states @@ -459,9 +465,13 @@ def forward( self.activation, ) else: + if use_cached_chunk: + # Prepend the cached conv context so the causal conv sees the correct left-context + # when continuing a prior forward rather than starting from zero-padding. 
+ mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: - conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.update_conv_state(new_conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -471,7 +481,10 @@ def forward( seq_idx=None, ) else: - mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) + if use_cached_chunk: + # Drop the prepended context; only this chunk's outputs remain. + mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -502,7 +515,7 @@ def forward( value, g=g, beta=beta, - initial_state=None, + initial_state=recurrent_state if use_cached_chunk else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index f159901bec15..c8071d565a46 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -218,9 +218,15 @@ def forward( use_precomputed_states = ( cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) + # Multi-token forward after a prior forward has populated the cache (e.g. chunked prefill + # continuation or speculative verification). The single-token fused kernel can't be used + # here, so we fall into the chunked kernel and must carry forward both the cached conv + # context and the cached recurrent state; otherwise the first-position outputs silently + # diverge from the equivalent single-token decode. + use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len > 1 # getting projected states from cache if it exists - if use_precomputed_states: + if use_precomputed_states or use_cached_chunk: conv_state = cache_params.layers[self.layer_idx].conv_states recurrent_state = cache_params.layers[self.layer_idx].recurrent_states @@ -244,9 +250,13 @@ def forward( self.activation, ) else: + if use_cached_chunk: + # Prepend the cached conv context so the causal conv sees the correct left-context + # when continuing a prior forward rather than starting from zero-padding. + mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: - conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.update_conv_state(new_conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -256,7 +266,10 @@ def forward( seq_idx=None, ) else: - mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) + if use_cached_chunk: + # Drop the prepended context; only this chunk's outputs remain. 
+ mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -287,7 +300,7 @@ def forward( value, g=g, beta=beta, - initial_state=None, + initial_state=recurrent_state if use_cached_chunk else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 0b2a6a06aa85..8979712bb373 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -434,9 +434,15 @@ def forward( use_precomputed_states = ( cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 ) + # Multi-token forward after a prior forward has populated the cache (e.g. chunked prefill + # continuation or speculative verification). The single-token fused kernel can't be used + # here, so we fall into the chunked kernel and must carry forward both the cached conv + # context and the cached recurrent state; otherwise the first-position outputs silently + # diverge from the equivalent single-token decode. + use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len > 1 # getting projected states from cache if it exists - if use_precomputed_states: + if use_precomputed_states or use_cached_chunk: conv_state = cache_params.layers[self.layer_idx].conv_states recurrent_state = cache_params.layers[self.layer_idx].recurrent_states @@ -460,9 +466,13 @@ def forward( self.activation, ) else: + if use_cached_chunk: + # Prepend the cached conv context so the causal conv sees the correct left-context + # when continuing a prior forward rather than starting from zero-padding. + mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: - conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.update_conv_state(new_conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -472,7 +482,10 @@ def forward( seq_idx=None, ) else: - mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) + if use_cached_chunk: + # Drop the prepended context; only this chunk's outputs remain. 
+ mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -503,7 +516,7 @@ def forward( value, g=g, beta=beta, - initial_state=None, + initial_state=recurrent_state if use_cached_chunk else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/tests/models/qwen3_5/test_modeling_qwen3_5.py b/tests/models/qwen3_5/test_modeling_qwen3_5.py index 7725d2891a33..668a4e513970 100644 --- a/tests/models/qwen3_5/test_modeling_qwen3_5.py +++ b/tests/models/qwen3_5/test_modeling_qwen3_5.py @@ -38,6 +38,7 @@ import torch from transformers import ( + DynamicCache, Qwen3_5Config, Qwen3_5ForCausalLM, Qwen3_5ForConditionalGeneration, @@ -142,6 +143,46 @@ def test_multi_gpu_data_parallel_forward(self): def test_reverse_loading_mapping(self, check_keys_were_modified=True): pass + def test_linear_attention_multi_token_cached_forward_matches_single_token(self): + """ + Qwen3.5's gated-delta-net layers must produce the same output for a token regardless of + whether it's fed as a single-token cached forward or as the first token of a multi-token + chunk after the cache has been populated (chunked prefill continuation / speculative + verification pattern). A causal LM can never have its logits at position `i` depend on + tokens at positions > `i`, even across separate forward calls with a shared cache. + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config._attn_implementation = "eager" + # GatedDeltaNet's fused norm-gate kernel only supports silu/swish/sigmoid; the shared + # tester's default `gelu` would raise before we get to exercise the cache path. + config.hidden_act = "silu" + model = Qwen3_5TextModel._from_config(config) + model.to(torch_device) + model.eval() + + prefill_len = 8 + prompt = ids_tensor((1, prefill_len), config.vocab_size).to(torch_device) + next_token = ids_tensor((1, 1), config.vocab_size).to(torch_device) + + # Reference: prefill, then forward the next token alone with the populated cache. + cache_single = DynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_single, use_cache=True) + single_out = model(input_ids=next_token, past_key_values=cache_single, use_cache=True) + ref_first = single_out.last_hidden_state[:, 0, :] + + # Under test: prefill, then forward [next_token, *distractors] in one call. The first + # position must match the single-token forward exactly (causal attention property). 
+ distractors = ids_tensor((1, 7), config.vocab_size).to(torch_device) + multi_input = torch.cat([next_token, distractors], dim=1) + cache_multi = DynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_multi, use_cache=True) + multi_out = model(input_ids=multi_input, past_key_values=cache_multi, use_cache=True) + under_test_first = multi_out.last_hidden_state[:, 0, :] + + torch.testing.assert_close(under_test_first, ref_first, rtol=1e-4, atol=1e-4) + class Qwen3_5VisionText2TextModelTester: def __init__( From e395f5ac488b61db67bc2eba4c83ef5f7c069f82 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 23 Apr 2026 13:30:45 +0000 Subject: [PATCH 2/4] Review feedback: unify cached-forward state flag, gate single-token/chunk locally MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the two `use_precomputed_states` / `use_cached_chunk` variables with a single `use_precomputed_states = cache_params is not None and cache_params.has_previous_state(...)` that just signals "we have cached conv/recurrent state to continue from". The split between the single-token (fused per-step) and chunk-tokens (chunk kernel + cached conv context) modes is now expressed locally via `seq_len == 1` checks at the three places where it actually matters — kernel dispatch, conv-context prepend, and prepend-drop slice — as requested in review. Behavior is unchanged; this is pure restructuring for clarity. `modeling_qwen3_5.py` and `modeling_qwen3_5_moe.py` regenerated via `check_modular_conversion.py --fix_and_overwrite`. --- .../models/qwen3_5/modeling_qwen3_5.py | 40 +++++++++---------- .../models/qwen3_5/modular_qwen3_5.py | 40 +++++++++---------- .../qwen3_5_moe/modeling_qwen3_5_moe.py | 40 +++++++++---------- 3 files changed, 54 insertions(+), 66 deletions(-) diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 3c68fe17a270..d001073751da 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -430,18 +430,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = ( - cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - ) - # Multi-token forward after a prior forward has populated the cache (e.g. chunked prefill - # continuation or speculative verification). The single-token fused kernel can't be used - # here, so we fall into the chunked kernel and must carry forward both the cached conv - # context and the cached recurrent state; otherwise the first-position outputs silently - # diverge from the equivalent single-token decode. - use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len > 1 + # We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes + # (single-token decode and chunk-tokens continuation) share the state read here; they only + # diverge in how the conv input is assembled and which kernel consumes the states below, + # which we gate locally on `seq_len`. 
+ use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # getting projected states from cache if it exists - if use_precomputed_states or use_cached_chunk: + if use_precomputed_states: conv_state = cache_params.layers[self.layer_idx].conv_states recurrent_state = cache_params.layers[self.layer_idx].recurrent_states @@ -454,9 +450,8 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - if use_precomputed_states: - # 2. Convolution sequence transformation - # NOTE: the conv state is updated in `causal_conv1d_update` + if use_precomputed_states and seq_len == 1: + # Single-token cached decode: the fused per-step kernel updates the conv state in-place. mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -465,9 +460,10 @@ def forward( self.activation, ) else: - if use_cached_chunk: - # Prepend the cached conv context so the causal conv sees the correct left-context - # when continuing a prior forward rather than starting from zero-padding. + if use_precomputed_states: + # Multi-token cached continuation (`seq_len > 1`, cache already populated). Prepend + # the cached conv context so the causal conv sees the correct left-context rather + # than zero-padding. mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) @@ -482,7 +478,7 @@ def forward( ) else: mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) - if use_cached_chunk: + if use_precomputed_states: # Drop the prepended context; only this chunk's outputs remain. mixed_qkv = mixed_qkv[:, :, -seq_len:] @@ -508,26 +504,26 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - if not use_precomputed_states: - core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + if use_precomputed_states and seq_len == 1: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state if use_cached_chunk else None, + initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: - core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed_states else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index c8071d565a46..94ce3fc55d6a 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -215,18 +215,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = ( - cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - ) - # Multi-token forward after a prior forward has populated the cache (e.g. chunked prefill - # continuation or speculative verification). 
The single-token fused kernel can't be used - # here, so we fall into the chunked kernel and must carry forward both the cached conv - # context and the cached recurrent state; otherwise the first-position outputs silently - # diverge from the equivalent single-token decode. - use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len > 1 + # We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes + # (single-token decode and chunk-tokens continuation) share the state read here; they only + # diverge in how the conv input is assembled and which kernel consumes the states below, + # which we gate locally on `seq_len`. + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # getting projected states from cache if it exists - if use_precomputed_states or use_cached_chunk: + if use_precomputed_states: conv_state = cache_params.layers[self.layer_idx].conv_states recurrent_state = cache_params.layers[self.layer_idx].recurrent_states @@ -239,9 +235,8 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - if use_precomputed_states: - # 2. Convolution sequence transformation - # NOTE: the conv state is updated in `causal_conv1d_update` + if use_precomputed_states and seq_len == 1: + # Single-token cached decode: the fused per-step kernel updates the conv state in-place. mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -250,9 +245,10 @@ def forward( self.activation, ) else: - if use_cached_chunk: - # Prepend the cached conv context so the causal conv sees the correct left-context - # when continuing a prior forward rather than starting from zero-padding. + if use_precomputed_states: + # Multi-token cached continuation (`seq_len > 1`, cache already populated). Prepend + # the cached conv context so the causal conv sees the correct left-context rather + # than zero-padding. mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) @@ -267,7 +263,7 @@ def forward( ) else: mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) - if use_cached_chunk: + if use_precomputed_states: # Drop the prepended context; only this chunk's outputs remain. 
mixed_qkv = mixed_qkv[:, :, -seq_len:] @@ -293,26 +289,26 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - if not use_precomputed_states: - core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + if use_precomputed_states and seq_len == 1: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state if use_cached_chunk else None, + initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: - core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed_states else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 8979712bb373..452bf85c003a 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -431,18 +431,14 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = ( - cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - ) - # Multi-token forward after a prior forward has populated the cache (e.g. chunked prefill - # continuation or speculative verification). The single-token fused kernel can't be used - # here, so we fall into the chunked kernel and must carry forward both the cached conv - # context and the cached recurrent state; otherwise the first-position outputs silently - # diverge from the equivalent single-token decode. - use_cached_chunk = cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len > 1 + # We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes + # (single-token decode and chunk-tokens continuation) share the state read here; they only + # diverge in how the conv input is assembled and which kernel consumes the states below, + # which we gate locally on `seq_len`. + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # getting projected states from cache if it exists - if use_precomputed_states or use_cached_chunk: + if use_precomputed_states: conv_state = cache_params.layers[self.layer_idx].conv_states recurrent_state = cache_params.layers[self.layer_idx].recurrent_states @@ -455,9 +451,8 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - if use_precomputed_states: - # 2. Convolution sequence transformation - # NOTE: the conv state is updated in `causal_conv1d_update` + if use_precomputed_states and seq_len == 1: + # Single-token cached decode: the fused per-step kernel updates the conv state in-place. mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -466,9 +461,10 @@ def forward( self.activation, ) else: - if use_cached_chunk: - # Prepend the cached conv context so the causal conv sees the correct left-context - # when continuing a prior forward rather than starting from zero-padding. 
+ if use_precomputed_states: + # Multi-token cached continuation (`seq_len > 1`, cache already populated). Prepend + # the cached conv context so the causal conv sees the correct left-context rather + # than zero-padding. mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) @@ -483,7 +479,7 @@ def forward( ) else: mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) - if use_cached_chunk: + if use_precomputed_states: # Drop the prepended context; only this chunk's outputs remain. mixed_qkv = mixed_qkv[:, :, -seq_len:] @@ -509,26 +505,26 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - if not use_precomputed_states: - core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + if use_precomputed_states and seq_len == 1: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state if use_cached_chunk else None, + initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: - core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed_states else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) From 29a76b360b26cd1bc39aa61e5fe9242cf5190b5f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 23 Apr 2026 16:33:29 +0000 Subject: [PATCH 3/4] Propagate linear-attention multi-token cached-forward fix to qwen3_next and olmo_hybrid MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit qwen3_next's GatedDeltaNet had the same bug as qwen3_5: for a multi-token forward after the cache was populated (chunked-prefill continuation or speculative-decoding verification), the chunk kernel was called with `initial_state=None` and the conv state was zero-padded, silently dropping the cached state. Applies the same pattern: unified `use_precomputed_states` flag (no seq_len condition), with single-token vs chunk-tokens routing gated locally on `seq_len == 1`. olmo_hybrid has the same kind of bug in its custom `OlmoHybridShortConvolution` — when the caller had cached state but fed more than one token (`use_precomputed=False` under the old `seq_len==1` gate), the conv was zero-padded instead of using the cached context. Fix the caller to drop the `seq_len==1` gate from `use_precomputed`, make the ShortConvolution branch on `seq_len==1` locally, and route recurrent vs chunk kernel dispatch in the caller likewise. olmo_hybrid's chunk kernel path already passed `initial_state=recurrent_state` correctly; no change needed there beyond the `use_precomputed` flag semantics. Add the same causal-LM invariance test to both suites (`test_linear_attention_multi_token_cached_forward_matches_single_token`). 100% element mismatch on the old code, passes after the fix. RUN_SLOW integration suites for both models pass end-to-end. `modeling_qwen3_next.py` and `modeling_olmo_hybrid.py` regenerated via `check_modular_conversion.py --fix_and_overwrite`. 
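For reviewers who want to sanity-check the conv-side reasoning without loading a model: the continuation identity the fix relies on is that prepending the cached left-context and then dropping that same region from the output reproduces the full-sequence causal conv exactly. A minimal self-contained sketch (illustration only, not part of the patch; `causal_conv` below is a hypothetical single-channel toy, and the real cache stores a length-`conv_kernel_size` window whose extra leading slot cannot reach the kept outputs):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    k = 4                                    # conv kernel size
    w = torch.randn(1, 1, k)                 # toy depthwise causal-conv weight

    def causal_conv(x):
        # Toy stand-in, not a transformers API: left-pad with k-1 zeros and
        # keep exactly len(x) outputs.
        return F.conv1d(x, w, padding=k - 1)[..., : x.shape[-1]]

    x = torch.randn(1, 1, 12)                # full 12-token sequence
    full = causal_conv(x)

    prefill, chunk = x[..., :8], x[..., 8:]  # prefill 8 tokens, continue with 4
    conv_state = prefill[..., -(k - 1):]     # cached left-context

    # Prepend the cached context, convolve, drop the prepended region.
    cont_in = torch.cat([conv_state, chunk], dim=-1)
    cont = causal_conv(cont_in)[..., -chunk.shape[-1]:]

    assert torch.allclose(cont, full[..., 8:], atol=1e-6)

This is the identity the `torch.cat([conv_state, mixed_qkv], dim=-1)` / `[:, :, -seq_len:]` pair implements per channel; zero-padding instead of prepending (the old behavior) makes the first `k - 1` outputs of the chunk wrong.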
--- .../olmo_hybrid/modeling_olmo_hybrid.py | 23 ++++++---- .../models/olmo_hybrid/modular_olmo_hybrid.py | 23 ++++++---- .../models/qwen3_next/modeling_qwen3_next.py | 37 +++++++++------- .../models/qwen3_next/modular_qwen3_next.py | 37 +++++++++------- .../olmo_hybrid/test_modeling_olmo_hybrid.py | 35 ++++++++++++++++ .../qwen3_next/test_modeling_qwen3_next.py | 42 +++++++++++++++++++ 6 files changed, 155 insertions(+), 42 deletions(-) diff --git a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py index 8ecb44916a57..75c23f07543d 100644 --- a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py @@ -229,8 +229,8 @@ def forward( hidden_states = hidden_states.transpose(1, 2) - if use_precomputed: - # Single token update (decoding mode) + if use_precomputed and seq_len == 1: + # Single-token decode: rolling-window update against the cached context. x_with_state = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d( x_with_state, @@ -241,10 +241,16 @@ def forward( ) conv_state = x_with_state[:, :, 1:] else: - # Multi-token forward (prefill mode) + if use_precomputed: + # Multi-token cached continuation (`seq_len > 1`, cache already populated): prepend + # the cached conv context so the causal conv sees the correct left-context rather + # than zero-padding. Drop the prepended region from the output below. + hidden_states = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d(hidden_states, self.weight, self.bias, padding=self.conv_kernel_size - 1, groups=dim) - out = out[:, :, :seq_len] + out = out[:, :, : hidden_states.shape[-1]] conv_state = F.pad(hidden_states, (self.conv_kernel_size - 1 - hidden_states.shape[-1], 0)) + if use_precomputed: + out = out[:, :, -seq_len:] out = self.act_fn(out) @@ -723,7 +729,10 @@ def forward( batch_size, seq_len, _ = hidden_states.shape use_cache = cache_params is not None - use_precomputed = use_cache and cache_params.has_previous_state() and seq_len == 1 + # Reads "we have cached conv/recurrent state to continue from". Single-token vs multi-token + # branching lives inside `ShortConvolution` and in the recurrent-vs-chunk kernel dispatch + # below, each of which gates on `seq_len == 1` locally. 
+ use_precomputed = use_cache and cache_params.has_previous_state() conv_state_q = cache_params.conv_states_q[self.layer_idx] if cache_params else None conv_state_k = cache_params.conv_states_k[self.layer_idx] if cache_params else None @@ -764,7 +773,7 @@ def forward( g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias) - if use_precomputed: + if use_precomputed and seq_len == 1: output, new_recurrent_state = self.recurrent_gated_delta_rule( q, k, @@ -782,7 +791,7 @@ def forward( v, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed else None, output_final_state=use_cache, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py index 089f29309007..ccefed107d4d 100644 --- a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py @@ -330,8 +330,8 @@ def forward( hidden_states = hidden_states.transpose(1, 2) - if use_precomputed: - # Single token update (decoding mode) + if use_precomputed and seq_len == 1: + # Single-token decode: rolling-window update against the cached context. x_with_state = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d( x_with_state, @@ -342,10 +342,16 @@ def forward( ) conv_state = x_with_state[:, :, 1:] else: - # Multi-token forward (prefill mode) + if use_precomputed: + # Multi-token cached continuation (`seq_len > 1`, cache already populated): prepend + # the cached conv context so the causal conv sees the correct left-context rather + # than zero-padding. Drop the prepended region from the output below. + hidden_states = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d(hidden_states, self.weight, self.bias, padding=self.conv_kernel_size - 1, groups=dim) - out = out[:, :, :seq_len] + out = out[:, :, : hidden_states.shape[-1]] conv_state = F.pad(hidden_states, (self.conv_kernel_size - 1 - hidden_states.shape[-1], 0)) + if use_precomputed: + out = out[:, :, -seq_len:] out = self.act_fn(out) @@ -541,7 +547,10 @@ def forward( batch_size, seq_len, _ = hidden_states.shape use_cache = cache_params is not None - use_precomputed = use_cache and cache_params.has_previous_state() and seq_len == 1 + # Reads "we have cached conv/recurrent state to continue from". Single-token vs multi-token + # branching lives inside `ShortConvolution` and in the recurrent-vs-chunk kernel dispatch + # below, each of which gates on `seq_len == 1` locally. 
+ use_precomputed = use_cache and cache_params.has_previous_state() conv_state_q = cache_params.conv_states_q[self.layer_idx] if cache_params else None conv_state_k = cache_params.conv_states_k[self.layer_idx] if cache_params else None @@ -582,7 +591,7 @@ def forward( g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias) - if use_precomputed: + if use_precomputed and seq_len == 1: output, new_recurrent_state = self.recurrent_gated_delta_rule( q, k, @@ -600,7 +609,7 @@ def forward( v, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed else None, output_final_state=use_cache, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index cd152e3d3e59..05f1359001a0 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -600,9 +600,11 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = ( - cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - ) + # We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes + # (single-token decode and chunk-tokens continuation) share the state read here; they only + # diverge in how the conv input is assembled and which kernel consumes the states below, + # which we gate locally on `seq_len`. + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # getting projected states from cache if it exists if use_precomputed_states: @@ -617,9 +619,8 @@ def forward( mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = mixed_qkv.transpose(1, 2) - if use_precomputed_states: - # 2. Convolution sequence transformation - # NOTE: the conv state is updated in `causal_conv1d_update` + if use_precomputed_states and seq_len == 1: + # Single-token cached decode: the fused per-step kernel updates the conv state in-place. mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -628,9 +629,14 @@ def forward( self.activation, ) else: + if use_precomputed_states: + # Multi-token cached continuation (`seq_len > 1`, cache already populated). Prepend + # the cached conv context so the causal conv sees the correct left-context rather + # than zero-padding. + mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: - conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.update_conv_state(new_conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -640,7 +646,10 @@ def forward( seq_idx=None, ) else: - mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) + if use_precomputed_states: + # Drop the prepended context; only this chunk's outputs remain. 
+ mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -663,25 +672,25 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - if not use_precomputed_states: - core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + if use_precomputed_states and seq_len == 1: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=None, + initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: - core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed_states else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 392d8325fe22..38686b926e08 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -439,9 +439,11 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = ( - cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - ) + # We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes + # (single-token decode and chunk-tokens continuation) share the state read here; they only + # diverge in how the conv input is assembled and which kernel consumes the states below, + # which we gate locally on `seq_len`. + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # getting projected states from cache if it exists if use_precomputed_states: @@ -456,9 +458,8 @@ def forward( mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = mixed_qkv.transpose(1, 2) - if use_precomputed_states: - # 2. Convolution sequence transformation - # NOTE: the conv state is updated in `causal_conv1d_update` + if use_precomputed_states and seq_len == 1: + # Single-token cached decode: the fused per-step kernel updates the conv state in-place. mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -467,9 +468,14 @@ def forward( self.activation, ) else: + if use_precomputed_states: + # Multi-token cached continuation (`seq_len > 1`, cache already populated). Prepend + # the cached conv context so the causal conv sees the correct left-context rather + # than zero-padding. 
+ mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: - conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.update_conv_state(new_conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -479,7 +485,10 @@ def forward( seq_idx=None, ) else: - mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) + if use_precomputed_states: + # Drop the prepended context; only this chunk's outputs remain. + mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -502,25 +511,25 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - if not use_precomputed_states: - core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + if use_precomputed_states and seq_len == 1: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=None, + initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: - core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed_states else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py b/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py index 37b2750a632d..56622d65407b 100644 --- a/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py +++ b/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py @@ -26,6 +26,7 @@ ) from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester +from ...test_modeling_common import ids_tensor if is_torch_available(): @@ -68,6 +69,40 @@ class OlmoHybridModelTest(CausalLMModelTest, unittest.TestCase): def test_tp_generation_quantized(self): pass + def test_linear_attention_multi_token_cached_forward_matches_single_token(self): + """ + OLMo-Hybrid's GatedDeltaNet layers must produce the same output for a token regardless of + whether it's fed as a single-token cached forward or as the first token of a multi-token chunk + after the cache has been populated (chunked-prefill continuation / speculative verification). + A causal LM's logits at position `i` cannot depend on tokens at positions > `i`, even across + separate forward calls with a shared cache. 
+ """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config._attn_implementation = "eager" + model = OlmoHybridModel._from_config(config) + model.to(torch_device) + model.eval() + + prefill_len = 8 + prompt = ids_tensor((1, prefill_len), config.vocab_size).to(torch_device) + next_token = ids_tensor((1, 1), config.vocab_size).to(torch_device) + + cache_single = OlmoHybridDynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_single, use_cache=True) + single_out = model(input_ids=next_token, past_key_values=cache_single, use_cache=True) + ref_first = single_out.last_hidden_state[:, 0, :] + + distractors = ids_tensor((1, 7), config.vocab_size).to(torch_device) + multi_input = torch.cat([next_token, distractors], dim=1) + cache_multi = OlmoHybridDynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_multi, use_cache=True) + multi_out = model(input_ids=multi_input, past_key_values=cache_multi, use_cache=True) + under_test_first = multi_out.last_hidden_state[:, 0, :] + + torch.testing.assert_close(under_test_first, ref_first, rtol=1e-4, atol=1e-4) + # === Cache helper methods (same pattern as Qwen3Next) === def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): """OlmoHybrid has a special Cache as it alternates with gated deltanet layers""" diff --git a/tests/models/qwen3_next/test_modeling_qwen3_next.py b/tests/models/qwen3_next/test_modeling_qwen3_next.py index 4cb53fb6c695..5f65bf6e54c9 100644 --- a/tests/models/qwen3_next/test_modeling_qwen3_next.py +++ b/tests/models/qwen3_next/test_modeling_qwen3_next.py @@ -25,6 +25,7 @@ import torch from transformers import ( + DynamicCache, Qwen3NextModel, ) @@ -32,6 +33,7 @@ from ...test_modeling_common import ( TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, _test_eager_matches_sdpa_inference, + ids_tensor, ) @@ -121,6 +123,46 @@ def test_attention_outputs(self): self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types)) self.assertListEqual(list(self_attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len]) + def test_linear_attention_multi_token_cached_forward_matches_single_token(self): + """ + Qwen3-Next's gated-delta-net layers must produce the same output for a token regardless of + whether it's fed as a single-token cached forward or as the first token of a multi-token chunk + after the cache has been populated (chunked-prefill continuation / speculative verification). + A causal LM's logits at position `i` cannot depend on tokens at positions > `i`, even across + separate forward calls with a shared cache. + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config._attn_implementation = "eager" + # GatedDeltaNet's fused norm-gate kernel only supports silu/swish/sigmoid; the shared tester + # default `gelu` would raise before exercising the cache path. + config.hidden_act = "silu" + model = Qwen3NextModel._from_config(config) + model.to(torch_device) + model.eval() + + prefill_len = 8 + prompt = ids_tensor((1, prefill_len), config.vocab_size).to(torch_device) + next_token = ids_tensor((1, 1), config.vocab_size).to(torch_device) + + # Reference: prefill, then forward the next token alone with the populated cache. 
+ cache_single = DynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_single, use_cache=True) + single_out = model(input_ids=next_token, past_key_values=cache_single, use_cache=True) + ref_first = single_out.last_hidden_state[:, 0, :] + + # Under test: prefill, then forward [next_token, *distractors] in one call. The first + # position must match the single-token forward exactly (causal attention). + distractors = ids_tensor((1, 7), config.vocab_size).to(torch_device) + multi_input = torch.cat([next_token, distractors], dim=1) + cache_multi = DynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_multi, use_cache=True) + multi_out = model(input_ids=multi_input, past_key_values=cache_multi, use_cache=True) + under_test_first = multi_out.last_hidden_state[:, 0, :] + + torch.testing.assert_close(under_test_first, ref_first, rtol=1e-4, atol=1e-4) + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) def test_eager_matches_sdpa_inference( self, From 40f471b930b1dd40a78a59a855231663b351f951 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 27 Apr 2026 13:02:35 +0000 Subject: [PATCH 4/4] Review feedback: keep "prefill mode / multi-token decode" comment, label prepend block MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restore the historical `# Multi-token forward (prefill mode)` comment on the chunk-mode else-branch in olmo_hybrid (and the equivalent qwen3_5 / qwen3_next paths) and adjust the wording so the two intents — fresh prefill vs. cached chunked-tokens decode — are visible on the same line. Tag the conv-context prepend block with a "dropped at the end of this branch" hint so a reader knows the mirror operation exists. Comment-only; behavior unchanged. Generated files re-emitted via `check_modular_conversion.py --fix_and_overwrite`. --- .../models/olmo_hybrid/modeling_olmo_hybrid.py | 7 ++++--- .../models/olmo_hybrid/modular_olmo_hybrid.py | 7 ++++--- src/transformers/models/qwen3_5/modeling_qwen3_5.py | 8 ++++---- src/transformers/models/qwen3_5/modular_qwen3_5.py | 8 ++++---- .../models/qwen3_5_moe/modeling_qwen3_5_moe.py | 8 ++++---- src/transformers/models/qwen3_next/modeling_qwen3_next.py | 8 ++++---- src/transformers/models/qwen3_next/modular_qwen3_next.py | 8 ++++---- 7 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py index 75c23f07543d..563680286b64 100644 --- a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py @@ -241,10 +241,11 @@ def forward( ) conv_state = x_with_state[:, :, 1:] else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). if use_precomputed: - # Multi-token cached continuation (`seq_len > 1`, cache already populated): prepend - # the cached conv context so the causal conv sees the correct left-context rather - # than zero-padding. Drop the prepended region from the output below. + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. 
hidden_states = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d(hidden_states, self.weight, self.bias, padding=self.conv_kernel_size - 1, groups=dim) out = out[:, :, : hidden_states.shape[-1]] diff --git a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py index ccefed107d4d..7e40d6d61f5d 100644 --- a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py @@ -342,10 +342,11 @@ def forward( ) conv_state = x_with_state[:, :, 1:] else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). if use_precomputed: - # Multi-token cached continuation (`seq_len > 1`, cache already populated): prepend - # the cached conv context so the causal conv sees the correct left-context rather - # than zero-padding. Drop the prepended region from the output below. + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. hidden_states = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d(hidden_states, self.weight, self.bias, padding=self.conv_kernel_size - 1, groups=dim) out = out[:, :, : hidden_states.shape[-1]] diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index d001073751da..904d08a5570f 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -460,10 +460,11 @@ def forward( self.activation, ) else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). if use_precomputed_states: - # Multi-token cached continuation (`seq_len > 1`, cache already populated). Prepend - # the cached conv context so the causal conv sees the correct left-context rather - # than zero-padding. + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) @@ -479,7 +480,6 @@ def forward( else: mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) if use_precomputed_states: - # Drop the prepended context; only this chunk's outputs remain. mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index 94ce3fc55d6a..710b63a28dba 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -245,10 +245,11 @@ def forward( self.activation, ) else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). if use_precomputed_states: - # Multi-token cached continuation (`seq_len > 1`, cache already populated). Prepend - # the cached conv context so the causal conv sees the correct left-context rather - # than zero-padding. + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. 
mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) @@ -264,7 +265,6 @@ def forward( else: mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) if use_precomputed_states: - # Drop the prepended context; only this chunk's outputs remain. mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 452bf85c003a..bd8a49c969ca 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -461,10 +461,11 @@ def forward( self.activation, ) else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). if use_precomputed_states: - # Multi-token cached continuation (`seq_len > 1`, cache already populated). Prepend - # the cached conv context so the causal conv sees the correct left-context rather - # than zero-padding. + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) @@ -480,7 +481,6 @@ def forward( else: mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) if use_precomputed_states: - # Drop the prepended context; only this chunk's outputs remain. mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 05f1359001a0..24f4d8a47b29 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -629,10 +629,11 @@ def forward( self.activation, ) else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). if use_precomputed_states: - # Multi-token cached continuation (`seq_len > 1`, cache already populated). Prepend - # the cached conv context so the causal conv sees the correct left-context rather - # than zero-padding. + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) @@ -648,7 +649,6 @@ def forward( else: mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) if use_precomputed_states: - # Drop the prepended context; only this chunk's outputs remain. mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 38686b926e08..ef55cbdda3f2 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -468,10 +468,11 @@ def forward( self.activation, ) else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). 
if use_precomputed_states: - # Multi-token cached continuation (`seq_len > 1`, cache already populated). Prepend - # the cached conv context so the causal conv sees the correct left-context rather - # than zero-padding. + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) @@ -487,7 +488,6 @@ def forward( else: mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) if use_precomputed_states: - # Drop the prepended context; only this chunk's outputs remain. mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2)
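Note for reviewers (illustration only, not part of any patch): the recurrent-state half of the invariant, i.e. why `initial_state=None` corrupted the chunk outputs, can be reproduced with a toy linear-state recurrence standing in for the gated delta rule. The real kernels additionally apply the delta-rule correction, exp/softplus gating, and q/k L2-norm, none of which change the cache semantics sketched here:

    import torch

    torch.manual_seed(0)
    d, T, split = 4, 12, 8
    q, k, v = (torch.randn(T, d) for _ in range(3))
    a = torch.rand(T)  # per-step decay gate in (0, 1)

    def run(qs, ks, vs, gs, state=None):
        # Toy stand-in for chunk_gated_delta_rule's cache semantics
        # (hypothetical, much simpler update than the real kernel).
        S = torch.zeros(d, d) if state is None else state
        outs = []
        for t in range(len(qs)):
            S = gs[t] * S + torch.outer(ks[t], vs[t])  # state update
            outs.append(qs[t] @ S)                     # readout
        return torch.stack(outs), S

    full, _ = run(q, k, v, a)                                         # one call
    _, state = run(q[:split], k[:split], v[:split], a[:split])        # prefill
    good, _ = run(q[split:], k[split:], v[split:], a[split:], state)  # fixed path
    bad, _ = run(q[split:], k[split:], v[split:], a[split:])          # old bug

    assert torch.allclose(good, full[split:], atol=1e-6)              # matches
    assert not torch.allclose(bad, full[split:], atol=1e-3)           # diverges

The unit tests added in this series assert exactly the `good == full` half of this on the real models, comparing the first position of a multi-token cached forward against the single-token cached decode on the same token and cache.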