24 changes: 17 additions & 7 deletions src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py
@@ -229,8 +229,8 @@ def forward(

hidden_states = hidden_states.transpose(1, 2)

if use_precomputed:
# Single token update (decoding mode)
if use_precomputed and seq_len == 1:
# Single-token decode: rolling-window update against the cached context.
x_with_state = torch.cat([cache, hidden_states], dim=-1)
out = F.conv1d(
x_with_state,
@@ -241,10 +241,17 @@
)
conv_state = x_with_state[:, :, 1:]
else:
# Multi-token forward (prefill mode)
# Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state).
if use_precomputed:
# Cached chunked-tokens decode: prepend the cached conv context so the causal conv
# sees the correct left context rather than zero-padding. The prepended span is
# dropped from the output at the end of this branch.
hidden_states = torch.cat([cache, hidden_states], dim=-1)
out = F.conv1d(hidden_states, self.weight, self.bias, padding=self.conv_kernel_size - 1, groups=dim)
out = out[:, :, :seq_len]
out = out[:, :, : hidden_states.shape[-1]]
conv_state = F.pad(hidden_states, (self.conv_kernel_size - 1 - hidden_states.shape[-1], 0))
if use_precomputed:
out = out[:, :, -seq_len:]

out = self.act_fn(out)

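A minimal sketch of the cache invariant both branches above rely on (standalone, toy sizes; not the model code): the single-token rolling-window update and the prepend-then-slice chunked path each reproduce what a full causal conv over the whole sequence would emit.

```python
# Toy check of the ShortConvolution cache invariant (hypothetical sizes).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, k, T = 4, 4, 10
weight = torch.randn(dim, 1, k)  # depthwise: one kernel per channel
x = torch.randn(1, dim, T)

# Reference: full causal conv over the whole sequence (zero left-padding).
full = F.conv1d(x, weight, padding=k - 1, groups=dim)[:, :, :T]

# Single-token decode: roll a window of the last k-1 inputs, convolve
# [cache | token] with no padding, exactly as the `seq_len == 1` branch does.
cache = torch.zeros(1, dim, k - 1)
outs = []
for t in range(T):
    x_with_state = torch.cat([cache, x[:, :, t : t + 1]], dim=-1)
    outs.append(F.conv1d(x_with_state, weight, groups=dim))
    cache = x_with_state[:, :, 1:]  # slide the window
print(torch.allclose(full, torch.cat(outs, dim=-1), atol=1e-6))  # True

# Chunked continuation, as in the cached multi-token branch: prepend the
# cached context, run the padded conv, keep only the last seq_len outputs.
T1 = 6
state = F.pad(x[:, :, :T1], (k - 1 - T1, 0))  # last k-1 inputs of the prefix
chunk = torch.cat([state, x[:, :, T1:]], dim=-1)
out2 = F.conv1d(chunk, weight, padding=k - 1, groups=dim)[:, :, : chunk.shape[-1]]
print(torch.allclose(full[:, :, T1:], out2[:, :, -(T - T1) :], atol=1e-6))  # True
```

(Bias and activation are dropped for brevity; both are pointwise and do not affect the alignment being checked.)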
@@ -723,7 +730,10 @@ def forward(
batch_size, seq_len, _ = hidden_states.shape

use_cache = cache_params is not None
use_precomputed = use_cache and cache_params.has_previous_state() and seq_len == 1
# Reads as "we have cached conv/recurrent state to continue from". Single-token vs. multi-token
# branching lives inside `ShortConvolution` and in the recurrent-vs-chunk kernel dispatch
# below, each of which gates on `seq_len == 1` locally.
use_precomputed = use_cache and cache_params.has_previous_state()

conv_state_q = cache_params.conv_states_q[self.layer_idx] if cache_params else None
conv_state_k = cache_params.conv_states_k[self.layer_idx] if cache_params else None
@@ -764,7 +774,7 @@

g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)

if use_precomputed:
if use_precomputed and seq_len == 1:
output, new_recurrent_state = self.recurrent_gated_delta_rule(
q,
k,
@@ -782,7 +792,7 @@
v,
g=g,
beta=beta,
initial_state=recurrent_state,
initial_state=recurrent_state if use_precomputed else None,
output_final_state=use_cache,
use_qk_l2norm_in_kernel=True,
)
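For intuition on why the `initial_state` threading makes the chunked path a faithful continuation: under one common formulation of the gated delta rule (a naive single-head reference; the model calls fused `fla`-style kernels, and the qk l2-norm step is omitted here), splitting a sequence and carrying the final state reproduces the unsplit run.

```python
# Naive gated delta rule recurrence (reference only, assumed formulation).
import torch

def gated_delta_rule_ref(q, k, v, g, beta, S=None):
    """q, k: (T, d_k); v: (T, d_v); g, beta: (T,); S: (d_k, d_v) carried state."""
    if S is None:
        S = q.new_zeros(q.shape[-1], v.shape[-1])
    outs = []
    for t in range(q.shape[0]):
        S = S * g[t].exp()                        # gated decay of the memory
        err = v[t] - k[t] @ S                     # delta-rule prediction error
        S = S + beta[t] * torch.outer(k[t], err)  # rank-1 correction
        outs.append(q[t] @ S)                     # read-out
    return torch.stack(outs), S

torch.manual_seed(0)
T, d_k, d_v = 8, 4, 4
q, k, v = torch.randn(T, d_k), torch.randn(T, d_k), torch.randn(T, d_v)
g, beta = -torch.rand(T), torch.rand(T)

full, _ = gated_delta_rule_ref(q, k, v, g, beta)
out1, S1 = gated_delta_rule_ref(q[:5], k[:5], v[:5], g[:5], beta[:5])
out2, _ = gated_delta_rule_ref(q[5:], k[5:], v[5:], g[5:], beta[5:], S=S1)
print(torch.allclose(full, torch.cat([out1, out2]), atol=1e-5))  # True
```

The `seq_len == 1` gate then just picks the cheaper per-step kernel for the same recurrence.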
24 changes: 17 additions & 7 deletions src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py
@@ -330,8 +330,8 @@ def forward(

hidden_states = hidden_states.transpose(1, 2)

if use_precomputed:
# Single token update (decoding mode)
if use_precomputed and seq_len == 1:
# Single-token decode: rolling-window update against the cached context.
x_with_state = torch.cat([cache, hidden_states], dim=-1)
out = F.conv1d(
x_with_state,
@@ -342,10 +342,17 @@
)
conv_state = x_with_state[:, :, 1:]
else:
# Multi-token forward (prefill mode)
# Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state).
if use_precomputed:
# Cached chunked-tokens decode: prepend the cached conv context so the causal conv
# sees the correct left context rather than zero-padding. The prepended span is
# dropped from the output at the end of this branch.
hidden_states = torch.cat([cache, hidden_states], dim=-1)
out = F.conv1d(hidden_states, self.weight, self.bias, padding=self.conv_kernel_size - 1, groups=dim)
out = out[:, :, :seq_len]
out = out[:, :, : hidden_states.shape[-1]]
conv_state = F.pad(hidden_states, (self.conv_kernel_size - 1 - hidden_states.shape[-1], 0))
if use_precomputed:
out = out[:, :, -seq_len:]

out = self.act_fn(out)

@@ -541,7 +548,10 @@ def forward(
batch_size, seq_len, _ = hidden_states.shape

use_cache = cache_params is not None
use_precomputed = use_cache and cache_params.has_previous_state() and seq_len == 1
# Reads as "we have cached conv/recurrent state to continue from". Single-token vs. multi-token
# branching lives inside `ShortConvolution` and in the recurrent-vs-chunk kernel dispatch
# below, each of which gates on `seq_len == 1` locally.
use_precomputed = use_cache and cache_params.has_previous_state()

conv_state_q = cache_params.conv_states_q[self.layer_idx] if cache_params else None
conv_state_k = cache_params.conv_states_k[self.layer_idx] if cache_params else None
@@ -582,7 +592,7 @@

g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)

if use_precomputed:
if use_precomputed and seq_len == 1:
output, new_recurrent_state = self.recurrent_gated_delta_rule(
q,
k,
@@ -600,7 +610,7 @@
v,
g=g,
beta=beta,
initial_state=recurrent_state,
initial_state=recurrent_state if use_precomputed else None,
output_final_state=use_cache,
use_qk_l2norm_in_kernel=True,
)
37 changes: 23 additions & 14 deletions src/transformers/models/qwen3_5/modeling_qwen3_5.py
@@ -430,9 +430,11 @@ def forward(
# Set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape

use_precomputed_states = (
cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
)
# We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes
# (single-token decode and chunked-tokens continuation) share the state read here; they
# diverge only in how the conv input is assembled and which kernel consumes the states
# below, which we gate locally on `seq_len`.
use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx)

# getting projected states from cache if it exists
if use_precomputed_states:
@@ -448,9 +450,8 @@
b = self.in_proj_b(hidden_states)
a = self.in_proj_a(hidden_states)

if use_precomputed_states:
# 2. Convolution sequence transformation
# NOTE: the conv state is updated in `causal_conv1d_update`
if use_precomputed_states and seq_len == 1:
# Single-token cached decode: the fused per-step kernel updates the conv state in-place.
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
@@ -459,9 +460,15 @@
self.activation,
)
else:
# Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state).
if use_precomputed_states:
# Cached chunked-tokens decode: prepend the cached conv context so the causal conv
# sees the correct left context rather than zero-padding. The prepended span is
# dropped from the output at the end of this branch.
mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1)
if cache_params is not None:
conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
conv_state = cache_params.update_conv_state(conv_state, self.layer_idx)
new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
cache_params.update_conv_state(new_conv_state, self.layer_idx)
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
Expand All @@ -471,7 +478,9 @@ def forward(
seq_idx=None,
)
else:
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]])
if use_precomputed_states:
mixed_qkv = mixed_qkv[:, :, -seq_len:]

mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
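Two details in this hunk are easy to misread. The slice bound moves from `seq_len` to `mixed_qkv.shape[-1]` because, once the cached context is prepended, the conv input is longer than the incoming chunk; the trailing `[-seq_len:]` then drops the context positions. And the conv-state write relies on `F.pad` cropping when the left pad is negative, so one expression handles both a short prefill (zero-fill) and a long one (keep the last `conv_kernel_size` inputs). A standalone sketch of the pad behavior, with toy sizes:

```python
# F.pad with a negative left pad crops, so the same expression both zero-fills
# short sequences and keeps the last k inputs of long ones (toy sizes).
import torch
import torch.nn.functional as F

k = 4
short = torch.arange(1.0, 3.0).view(1, 1, 2)  # T = 2 < k
long = torch.arange(1.0, 7.0).view(1, 1, 6)   # T = 6 > k

print(F.pad(short, (k - short.shape[-1], 0)))  # tensor([[[0., 0., 1., 2.]]])
print(F.pad(long, (k - long.shape[-1], 0)))    # tensor([[[3., 4., 5., 6.]]])
```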
@@ -495,26 +504,26 @@
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

if not use_precomputed_states:
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
if use_precomputed_states and seq_len == 1:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
initial_state=recurrent_state,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)

else:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=recurrent_state,
initial_state=recurrent_state if use_precomputed_states else None,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)
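The resulting dispatch has three modes rather than two. A minimal stand-in for the gating (names are descriptive, not the module's API):

```python
# Decision table for the kernel dispatch above (illustrative stand-in).
def pick_kernel(has_cached_state: bool, seq_len: int) -> tuple[str, bool]:
    """Returns (kernel_name, whether an initial recurrent state is passed)."""
    if has_cached_state and seq_len == 1:
        return "recurrent_gated_delta_rule", True  # per-step decode kernel
    return "chunk_gated_delta_rule", has_cached_state  # prefill or chunked decode

assert pick_kernel(False, 128) == ("chunk_gated_delta_rule", False)   # prefill
assert pick_kernel(True, 16) == ("chunk_gated_delta_rule", True)      # chunked decode
assert pick_kernel(True, 1) == ("recurrent_gated_delta_rule", True)   # token decode
```

Previously only the third row counted as a cached mode; a multi-token step with prior state fell into the prefill row and silently dropped the recurrent state.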
37 changes: 23 additions & 14 deletions src/transformers/models/qwen3_5/modular_qwen3_5.py
@@ -215,9 +215,11 @@ def forward(
# Set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape

use_precomputed_states = (
cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
)
# We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes
# (single-token decode and chunked-tokens continuation) share the state read here; they
# diverge only in how the conv input is assembled and which kernel consumes the states
# below, which we gate locally on `seq_len`.
use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx)

# getting projected states from cache if it exists
if use_precomputed_states:
@@ -233,9 +235,8 @@
b = self.in_proj_b(hidden_states)
a = self.in_proj_a(hidden_states)

if use_precomputed_states:
# 2. Convolution sequence transformation
# NOTE: the conv state is updated in `causal_conv1d_update`
if use_precomputed_states and seq_len == 1:
# Single-token cached decode: the fused per-step kernel updates the conv state in-place.
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
@@ -244,9 +245,15 @@
self.activation,
)
else:
# Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state).
if use_precomputed_states:
# Cached chunked-tokens decode: prepend the cached conv context so the causal conv
# sees the correct left context rather than zero-padding. The prepended span is
# dropped from the output at the end of this branch.
mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1)
if cache_params is not None:
conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
conv_state = cache_params.update_conv_state(conv_state, self.layer_idx)
new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
cache_params.update_conv_state(new_conv_state, self.layer_idx)
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
Expand All @@ -256,7 +263,9 @@ def forward(
seq_idx=None,
)
else:
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]])
if use_precomputed_states:
mixed_qkv = mixed_qkv[:, :, -seq_len:]

mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
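The fused `causal_conv1d_update` called in the `seq_len == 1` branch is not shown in this diff. A naive stand-in for its contract as used here (assumed semantics: roll the per-channel state, append the token, take one causal tap, then the activation) can be checked against the padded conv:

```python
# Naive reference for the fused per-step conv update (assumed contract,
# not the real causal-conv1d kernel).
import torch
import torch.nn.functional as F

def conv1d_update_ref(x_t, conv_state, weight):
    """x_t: (B, D, 1); conv_state: (B, D, k), updated in place; weight: (D, k)."""
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, :, -1] = x_t[:, :, 0]
    out = (conv_state * weight).sum(-1)  # one depthwise causal tap per channel
    return F.silu(out).unsqueeze(-1)

torch.manual_seed(0)
B, D, k, T = 1, 3, 4, 6
w = torch.randn(D, k)
state = torch.zeros(B, D, k)
xs = torch.randn(B, D, T)

step = torch.cat([conv1d_update_ref(xs[:, :, t : t + 1], state, w) for t in range(T)], -1)
full = F.silu(F.conv1d(xs, w.unsqueeze(1), padding=k - 1, groups=D)[:, :, :T])
print(torch.allclose(step, full, atol=1e-6))  # True
```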
@@ -280,26 +289,26 @@
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

if not use_precomputed_states:
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
if use_precomputed_states and seq_len == 1:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
initial_state=recurrent_state,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)

else:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=recurrent_state,
initial_state=recurrent_state if use_precomputed_states else None,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)
37 changes: 23 additions & 14 deletions src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
@@ -431,9 +431,11 @@ def forward(
# Set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape

use_precomputed_states = (
cache_params is not None and cache_params.has_previous_state(self.layer_idx) and seq_len == 1
)
# We have cached `conv_state` / `recurrent_state` to continue from. The two cached modes
# (single-token decode and chunked-tokens continuation) share the state read here; they
# diverge only in how the conv input is assembled and which kernel consumes the states
# below, which we gate locally on `seq_len`.
use_precomputed_states = cache_params is not None and cache_params.has_previous_state(self.layer_idx)

# getting projected states from cache if it exists
if use_precomputed_states:
@@ -449,9 +451,8 @@
b = self.in_proj_b(hidden_states)
a = self.in_proj_a(hidden_states)

if use_precomputed_states:
# 2. Convolution sequence transformation
# NOTE: the conv state is updated in `causal_conv1d_update`
if use_precomputed_states and seq_len == 1:
# Single-token cached decode: the fused per-step kernel updates the conv state in-place.
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
@@ -460,9 +461,15 @@
self.activation,
)
else:
# Multi-token forward (prefill, or chunked-tokens decode when the cache has prior state).
if use_precomputed_states:
# Cached chunked-tokens decode: prepend the cached conv context so the causal conv
# sees the correct left context rather than zero-padding. The prepended span is
# dropped from the output at the end of this branch.
mixed_qkv = torch.cat([conv_state, mixed_qkv], dim=-1)
if cache_params is not None:
conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
conv_state = cache_params.update_conv_state(conv_state, self.layer_idx)
new_conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
cache_params.update_conv_state(new_conv_state, self.layer_idx)
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
Expand All @@ -472,7 +479,9 @@ def forward(
seq_idx=None,
)
else:
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, : mixed_qkv.shape[-1]])
if use_precomputed_states:
mixed_qkv = mixed_qkv[:, :, -seq_len:]

mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
@@ -496,26 +505,26 @@
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

if not use_precomputed_states:
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
if use_precomputed_states and seq_len == 1:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
initial_state=recurrent_state,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)

else:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=recurrent_state,
initial_state=recurrent_state if use_precomputed_states else None,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)
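Across all three models the cache write pattern is the same: store exactly the last `conv_kernel_size` conv inputs per layer, however long the incoming chunk was. A toy cache sketch (hypothetical class; the real cache objects carry more machinery) showing that shape invariant:

```python
# Toy conv cache illustrating the fixed-width state write used above
# (hypothetical API modeled on the calls in this diff).
import torch
import torch.nn.functional as F

class ToyConvCache:
    def __init__(self, num_layers, dim, kernel_size):
        self.conv_states = [torch.zeros(1, dim, kernel_size) for _ in range(num_layers)]
        self._seen = [False] * num_layers

    def has_previous_state(self, layer_idx):
        return self._seen[layer_idx]

    def update_conv_state(self, new_state, layer_idx):
        self.conv_states[layer_idx].copy_(new_state)
        self._seen[layer_idx] = True

k = 4
cache = ToyConvCache(num_layers=1, dim=2, kernel_size=k)
chunk = torch.randn(1, 2, 7)  # longer than the kernel
cache.update_conv_state(F.pad(chunk, (k - chunk.shape[-1], 0)), layer_idx=0)
print(cache.conv_states[0].shape)  # torch.Size([1, 2, 4]) -- always width k
```

Renaming the local to `new_conv_state` in these hunks keeps this write from clobbering the `conv_state` variable read earlier in the same branch.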