diff --git a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py index 8ecb44916a57..563680286b64 100644 --- a/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py @@ -229,8 +229,8 @@ def forward( hidden_states = hidden_states.transpose(1, 2) - if use_precomputed: - # Single token update (decoding mode) + if use_precomputed and seq_len == 1: + # Single-token decode: rolling-window update against the cached context. x_with_state = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d( x_with_state, @@ -241,10 +241,17 @@ def forward( ) conv_state = x_with_state[:, :, 1:] else: - # Multi-token forward (prefill mode) + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). + if use_precomputed: + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. + hidden_states = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d(hidden_states, self.weight, self.bias, padding=self.conv_kernel_size - 1, groups=dim) - out = out[:, :, :seq_len] + out = out[:, :, : hidden_states.shape[-1]] conv_state = F.pad(hidden_states, (self.conv_kernel_size - 1 - hidden_states.shape[-1], 0)) + if use_precomputed: + out = out[:, :, -seq_len:] out = self.act_fn(out) @@ -723,7 +730,10 @@ def forward( batch_size, seq_len, _ = hidden_states.shape use_cache = cache_params is not None - use_precomputed = use_cache and cache_params.has_previous_state() and seq_len == 1 + # Reads "we have cached conv/recurrent state to continue from". Single-token vs multi-token + # branching lives inside `ShortConvolution` and in the recurrent-vs-chunk kernel dispatch + # below, each of which gates on `seq_len == 1` locally. 
+ use_precomputed = use_cache and cache_params.has_previous_state() conv_state_q = cache_params.conv_states_q[self.layer_idx] if cache_params else None conv_state_k = cache_params.conv_states_k[self.layer_idx] if cache_params else None @@ -764,7 +774,7 @@ def forward( g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias) - if use_precomputed: + if use_precomputed and seq_len == 1: output, new_recurrent_state = self.recurrent_gated_delta_rule( q, k, @@ -782,7 +792,7 @@ def forward( v, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed else None, output_final_state=use_cache, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py index 089f29309007..7e40d6d61f5d 100644 --- a/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py +++ b/src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py @@ -330,8 +330,8 @@ def forward( hidden_states = hidden_states.transpose(1, 2) - if use_precomputed: - # Single token update (decoding mode) + if use_precomputed and seq_len == 1: + # Single-token decode: rolling-window update against the cached context. x_with_state = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d( x_with_state, @@ -342,10 +342,17 @@ def forward( ) conv_state = x_with_state[:, :, 1:] else: - # Multi-token forward (prefill mode) + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). + if use_precomputed: + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. 
+ hidden_states = torch.cat([cache, hidden_states], dim=-1) out = F.conv1d(hidden_states, self.weight, self.bias, padding=self.conv_kernel_size - 1, groups=dim) - out = out[:, :, :seq_len] + out = out[:, :, : hidden_states.shape[-1]] conv_state = F.pad(hidden_states, (self.conv_kernel_size - 1 - hidden_states.shape[-1], 0)) + if use_precomputed: + out = out[:, :, -seq_len:] out = self.act_fn(out) @@ -541,7 +548,10 @@ def forward( batch_size, seq_len, _ = hidden_states.shape use_cache = cache_params is not None - use_precomputed = use_cache and cache_params.has_previous_state() and seq_len == 1 + # Reads "we have cached conv/recurrent state to continue from". Single-token vs multi-token + # branching lives inside `ShortConvolution` and in the recurrent-vs-chunk kernel dispatch + # below, each of which gates on `seq_len == 1` locally. + use_precomputed = use_cache and cache_params.has_previous_state() conv_state_q = cache_params.conv_states_q[self.layer_idx] if cache_params else None conv_state_k = cache_params.conv_states_k[self.layer_idx] if cache_params else None @@ -582,7 +592,7 @@ def forward( g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias) - if use_precomputed: + if use_precomputed and seq_len == 1: output, new_recurrent_state = self.recurrent_gated_delta_rule( q, k, @@ -600,7 +610,7 @@ def forward( v, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed else None, output_final_state=use_cache, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_5/modeling_qwen3_5.py b/src/transformers/models/qwen3_5/modeling_qwen3_5.py index 2c4eba9597dc..904d08a5570f 100644 --- a/src/transformers/models/qwen3_5/modeling_qwen3_5.py +++ b/src/transformers/models/qwen3_5/modeling_qwen3_5.py @@ -430,9 +430,11 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = ( - cache_params is not None and 
cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - ) + # We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes + # (single-token decode and chunk-tokens continuation) share the state read here; they only + # diverge in how the conv input is assembled and which kernel consumes the states below, + # which we gate locally on `seq_len`. + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # getting projected states from cache if it exists if use_precomputed_states: @@ -448,9 +450,8 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - if use_precomputed_states: - # 2. Convolution sequence transformation - # NOTE: the conv state is updated in `causal_conv1d_update` + if use_precomputed_states and seq_len == 1: + # Single-token cached decode: the fused per-step kernel updates the conv state in-place. mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -459,9 +460,15 @@ def forward( self.activation, ) else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). + if use_precomputed_states: + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. 
+ mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: - conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.update_conv_state(new_conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -471,7 +478,9 @@ def forward( seq_idx=None, ) else: - mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) + if use_precomputed_states: + mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -495,26 +504,26 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - if not use_precomputed_states: - core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + if use_precomputed_states and seq_len == 1: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=None, + initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: - core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed_states else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_5/modular_qwen3_5.py b/src/transformers/models/qwen3_5/modular_qwen3_5.py index f159901bec15..710b63a28dba 100644 --- a/src/transformers/models/qwen3_5/modular_qwen3_5.py +++ 
b/src/transformers/models/qwen3_5/modular_qwen3_5.py @@ -215,9 +215,11 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = ( - cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - ) + # We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes + # (single-token decode and chunk-tokens continuation) share the state read here; they only + # diverge in how the conv input is assembled and which kernel consumes the states below, + # which we gate locally on `seq_len`. + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # getting projected states from cache if it exists if use_precomputed_states: @@ -233,9 +235,8 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - if use_precomputed_states: - # 2. Convolution sequence transformation - # NOTE: the conv state is updated in `causal_conv1d_update` + if use_precomputed_states and seq_len == 1: + # Single-token cached decode: the fused per-step kernel updates the conv state in-place. mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -244,9 +245,15 @@ def forward( self.activation, ) else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). + if use_precomputed_states: + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. 
+ mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: - conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.update_conv_state(new_conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -256,7 +263,9 @@ def forward( seq_idx=None, ) else: - mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) + if use_precomputed_states: + mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -280,26 +289,26 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - if not use_precomputed_states: - core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + if use_precomputed_states and seq_len == 1: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=None, + initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: - core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed_states else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py index 0b2a6a06aa85..bd8a49c969ca 100644 --- a/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py +++ 
b/src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py @@ -431,9 +431,11 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = ( - cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - ) + # We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes + # (single-token decode and chunk-tokens continuation) share the state read here; they only + # diverge in how the conv input is assembled and which kernel consumes the states below, + # which we gate locally on `seq_len`. + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # getting projected states from cache if it exists if use_precomputed_states: @@ -449,9 +451,8 @@ def forward( b = self.in_proj_b(hidden_states) a = self.in_proj_a(hidden_states) - if use_precomputed_states: - # 2. Convolution sequence transformation - # NOTE: the conv state is updated in `causal_conv1d_update` + if use_precomputed_states and seq_len == 1: + # Single-token cached decode: the fused per-step kernel updates the conv state in-place. mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -460,9 +461,15 @@ def forward( self.activation, ) else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). + if use_precomputed_states: + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. 
+ mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: - conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.update_conv_state(new_conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -472,7 +479,9 @@ def forward( seq_idx=None, ) else: - mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) + if use_precomputed_states: + mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -496,26 +505,26 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - if not use_precomputed_states: - core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + if use_precomputed_states and seq_len == 1: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=None, + initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: - core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed_states else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index cd152e3d3e59..24f4d8a47b29 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ 
b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -600,9 +600,11 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = ( - cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - ) + # We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes + # (single-token decode and chunk-tokens continuation) share the state read here; they only + # diverge in how the conv input is assembled and which kernel consumes the states below, + # which we gate locally on `seq_len`. + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # getting projected states from cache if it exists if use_precomputed_states: @@ -617,9 +619,8 @@ def forward( mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = mixed_qkv.transpose(1, 2) - if use_precomputed_states: - # 2. Convolution sequence transformation - # NOTE: the conv state is updated in `causal_conv1d_update` + if use_precomputed_states and seq_len == 1: + # Single-token cached decode: the fused per-step kernel updates the conv state in-place. mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -628,9 +629,15 @@ def forward( self.activation, ) else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). + if use_precomputed_states: + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. 
+ mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: - conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.update_conv_state(new_conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -640,7 +647,9 @@ def forward( seq_idx=None, ) else: - mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) + if use_precomputed_states: + mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -663,25 +672,25 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - if not use_precomputed_states: - core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + if use_precomputed_states and seq_len == 1: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=None, + initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: - core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed_states else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index 392d8325fe22..ef55cbdda3f2 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ 
b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -439,9 +439,11 @@ def forward( # Set up dimensions for reshapes later batch_size, seq_len, _ = hidden_states.shape - use_precomputed_states = ( - cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1 - ) + # We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes + # (single-token decode and chunk-tokens continuation) share the state read here; they only + # diverge in how the conv input is assembled and which kernel consumes the states below, + # which we gate locally on `seq_len`. + use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx) # getting projected states from cache if it exists if use_precomputed_states: @@ -456,9 +458,8 @@ def forward( mixed_qkv = torch.cat((query, key, value), dim=-1) mixed_qkv = mixed_qkv.transpose(1, 2) - if use_precomputed_states: - # 2. Convolution sequence transformation - # NOTE: the conv state is updated in `causal_conv1d_update` + if use_precomputed_states and seq_len == 1: + # Single-token cached decode: the fused per-step kernel updates the conv state in-place. mixed_qkv = self.causal_conv1d_update( mixed_qkv, conv_state, @@ -467,9 +468,15 @@ def forward( self.activation, ) else: + # Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state). + if use_precomputed_states: + # Cached chunked-tokens decode: prepend the cached conv context so the causal conv + # sees the correct left-context rather than zero-padding. Dropped from the output + # at the end of this branch. 
+ mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1) if cache_params is not None: - conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) - conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + cache_params.update_conv_state(new_conv_state, self.layer_idx) if self.causal_conv1d_fn is not None: mixed_qkv = self.causal_conv1d_fn( x=mixed_qkv, @@ -479,7 +486,9 @@ def forward( seq_idx=None, ) else: - mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]]) + if use_precomputed_states: + mixed_qkv = mixed_qkv[:, :, -seq_len:] mixed_qkv = mixed_qkv.transpose(1, 2) query, key, value = torch.split( @@ -502,25 +511,25 @@ def forward( query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - if not use_precomputed_states: - core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + if use_precomputed_states and seq_len == 1: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=None, + initial_state=recurrent_state, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) else: - core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( query, key, value, g=g, beta=beta, - initial_state=recurrent_state, + initial_state=recurrent_state if use_precomputed_states else None, output_final_state=cache_params is not None, use_qk_l2norm_in_kernel=True, ) diff --git a/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py b/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py index 37b2750a632d..56622d65407b 100644 --- a/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py +++ 
b/tests/models/olmo_hybrid/test_modeling_olmo_hybrid.py @@ -26,6 +26,7 @@ ) from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester +from ...test_modeling_common import ids_tensor if is_torch_available(): @@ -68,6 +69,40 @@ class OlmoHybridModelTest(CausalLMModelTest, unittest.TestCase): def test_tp_generation_quantized(self): pass + def test_linear_attention_multi_token_cached_forward_matches_single_token(self): + """ + OLMo-Hybrid's GatedDeltaNet layers must produce the same output for a token regardless of + whether it's fed as a single-token cached forward or as the first token of a multi-token chunk + after the cache has been populated (chunked-prefill continuation / speculative verification). + A causal LM's logits at position `i` cannot depend on tokens at positions > `i`, even across + separate forward calls with a shared cache. + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config._attn_implementation = "eager" + model = OlmoHybridModel._from_config(config) + model.to(torch_device) + model.eval() + + prefill_len = 8 + prompt = ids_tensor((1, prefill_len), config.vocab_size).to(torch_device) + next_token = ids_tensor((1, 1), config.vocab_size).to(torch_device) + + cache_single = OlmoHybridDynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_single, use_cache=True) + single_out = model(input_ids=next_token, past_key_values=cache_single, use_cache=True) + ref_first = single_out.last_hidden_state[:, 0, :] + + distractors = ids_tensor((1, 7), config.vocab_size).to(torch_device) + multi_input = torch.cat([next_token, distractors], dim=1) + cache_multi = OlmoHybridDynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_multi, use_cache=True) + multi_out = model(input_ids=multi_input, past_key_values=cache_multi, use_cache=True) + under_test_first = multi_out.last_hidden_state[:, 0, :] + + 
torch.testing.assert_close(under_test_first, ref_first, rtol=1e-4, atol=1e-4) + # === Cache helper methods (same pattern as Qwen3Next) === def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): """OlmoHybrid has a special Cache as it alternates with gated deltanet layers""" diff --git a/tests/models/qwen3_5/test_modeling_qwen3_5.py b/tests/models/qwen3_5/test_modeling_qwen3_5.py index 7725d2891a33..668a4e513970 100644 --- a/tests/models/qwen3_5/test_modeling_qwen3_5.py +++ b/tests/models/qwen3_5/test_modeling_qwen3_5.py @@ -38,6 +38,7 @@ import torch from transformers import ( + DynamicCache, Qwen3_5Config, Qwen3_5ForCausalLM, Qwen3_5ForConditionalGeneration, @@ -142,6 +143,46 @@ def test_multi_gpu_data_parallel_forward(self): def test_reverse_loading_mapping(self, check_keys_were_modified=True): pass + def test_linear_attention_multi_token_cached_forward_matches_single_token(self): + """ + Qwen3.5's gated-delta-net layers must produce the same output for a token regardless of + whether it's fed as a single-token cached forward or as the first token of a multi-token + chunk after the cache has been populated (chunked prefill continuation / speculative + verification pattern). A causal LM can never have its logits at position `i` depend on + tokens at positions > `i`, even across separate forward calls with a shared cache. + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config._attn_implementation = "eager" + # GatedDeltaNet's fused norm-gate kernel only supports silu/swish/sigmoid; the shared + # tester's default `gelu` would raise before we get to exercise the cache path. 
+ config.hidden_act = "silu" + model = Qwen3_5TextModel._from_config(config) + model.to(torch_device) + model.eval() + + prefill_len = 8 + prompt = ids_tensor((1, prefill_len), config.vocab_size).to(torch_device) + next_token = ids_tensor((1, 1), config.vocab_size).to(torch_device) + + # Reference: prefill, then forward the next token alone with the populated cache. + cache_single = DynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_single, use_cache=True) + single_out = model(input_ids=next_token, past_key_values=cache_single, use_cache=True) + ref_first = single_out.last_hidden_state[:, 0, :] + + # Under test: prefill, then forward [next_token, *distractors] in one call. The first + # position must match the single-token forward up to numerical tolerance (causal attention property). + distractors = ids_tensor((1, 7), config.vocab_size).to(torch_device) + multi_input = torch.cat([next_token, distractors], dim=1) + cache_multi = DynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_multi, use_cache=True) + multi_out = model(input_ids=multi_input, past_key_values=cache_multi, use_cache=True) + under_test_first = multi_out.last_hidden_state[:, 0, :] + + torch.testing.assert_close(under_test_first, ref_first, rtol=1e-4, atol=1e-4) + class Qwen3_5VisionText2TextModelTester: def __init__( diff --git a/tests/models/qwen3_next/test_modeling_qwen3_next.py b/tests/models/qwen3_next/test_modeling_qwen3_next.py index 4cb53fb6c695..5f65bf6e54c9 100644 --- a/tests/models/qwen3_next/test_modeling_qwen3_next.py +++ b/tests/models/qwen3_next/test_modeling_qwen3_next.py @@ -25,6 +25,7 @@ import torch from transformers import ( + DynamicCache, Qwen3NextModel, ) @@ -32,6 +33,7 @@ from ...test_modeling_common import ( TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, _test_eager_matches_sdpa_inference, + ids_tensor, ) @@ -121,6 +123,46 @@ def test_attention_outputs(self): 
self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types)) self.assertListEqual(list(self_attentions[0].shape[-3:]), [config.num_attention_heads, seq_len, seq_len]) + def test_linear_attention_multi_token_cached_forward_matches_single_token(self): + """ + Qwen3-Next's gated-delta-net layers must produce the same output for a token regardless of + whether it's fed as a single-token cached forward or as the first token of a multi-token chunk + after the cache has been populated (chunked-prefill continuation / speculative verification). + A causal LM's logits at position `i` cannot depend on tokens at positions > `i`, even across + separate forward calls with a shared cache. + """ + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config._attn_implementation = "eager" + # GatedDeltaNet's fused norm-gate kernel only supports silu/swish/sigmoid; the shared tester + # default `gelu` would raise before exercising the cache path. + config.hidden_act = "silu" + model = Qwen3NextModel._from_config(config) + model.to(torch_device) + model.eval() + + prefill_len = 8 + prompt = ids_tensor((1, prefill_len), config.vocab_size).to(torch_device) + next_token = ids_tensor((1, 1), config.vocab_size).to(torch_device) + + # Reference: prefill, then forward the next token alone with the populated cache. + cache_single = DynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_single, use_cache=True) + single_out = model(input_ids=next_token, past_key_values=cache_single, use_cache=True) + ref_first = single_out.last_hidden_state[:, 0, :] + + # Under test: prefill, then forward [next_token, *distractors] in one call. The first + # position must match the single-token forward up to numerical tolerance (causal attention). 
+ distractors = ids_tensor((1, 7), config.vocab_size).to(torch_device) + multi_input = torch.cat([next_token, distractors], dim=1) + cache_multi = DynamicCache(config=config) + with torch.no_grad(): + model(input_ids=prompt, past_key_values=cache_multi, use_cache=True) + multi_out = model(input_ids=multi_input, past_key_values=cache_multi, use_cache=True) + under_test_first = multi_out.last_hidden_state[:, 0, :] + + torch.testing.assert_close(under_test_first, ref_first, rtol=1e-4, atol=1e-4) + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) def test_eager_matches_sdpa_inference( self,