From c42439ed5a325bed468367648894d884c103ce00 Mon Sep 17 00:00:00 2001
From: casinca <47400729+casinca@users.noreply.github.com>
Date: Tue, 31 Mar 2026 14:34:24 +0200
Subject: [PATCH] refactor(gpt-oss): rename `eager_attention_forward` to
 `eager_attention_forward_with_sink`

---
 src/transformers/models/gpt_oss/modeling_gpt_oss.py | 4 ++--
 src/transformers/models/gpt_oss/modular_gpt_oss.py  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
index 18f31ea90379..2109621f0206 100644
--- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -242,7 +242,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-def eager_attention_forward(
+def eager_attention_forward_with_sink(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -324,7 +324,7 @@ def forward(
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
         attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
-            self.config._attn_implementation, eager_attention_forward
+            self.config._attn_implementation, eager_attention_forward_with_sink
         )
 
         attn_output, attn_weights = attention_interface(
diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py
index 687e8864efeb..6142afd4a936 100644
--- a/src/transformers/models/gpt_oss/modular_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -179,7 +179,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-def eager_attention_forward(
+def eager_attention_forward_with_sink(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -249,7 +249,7 @@ def forward(
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
         attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
-            self.config._attn_implementation, eager_attention_forward
+            self.config._attn_implementation, eager_attention_forward_with_sink
         )
 
         attn_output, attn_weights = attention_interface(