diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
index 18f31ea90379..2109621f0206 100644
--- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -242,7 +242,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-def eager_attention_forward(
+def eager_attention_forward_with_sink(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -324,7 +324,7 @@ def forward(
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
         attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
-            self.config._attn_implementation, eager_attention_forward
+            self.config._attn_implementation, eager_attention_forward_with_sink
        )
 
         attn_output, attn_weights = attention_interface(
diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py
index 687e8864efeb..6142afd4a936 100644
--- a/src/transformers/models/gpt_oss/modular_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -179,7 +179,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-def eager_attention_forward(
+def eager_attention_forward_with_sink(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
@@ -249,7 +249,7 @@ def forward(
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
 
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
-            self.config._attn_implementation, eager_attention_forward
+            self.config._attn_implementation, eager_attention_forward_with_sink
        )
 
        attn_output, attn_weights = attention_interface(
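
For context on the rename: the GPT-OSS eager path is not the stock `eager_attention_forward` shared by other models; it folds a learned per-head "sink" logit into the softmax, so giving it a distinct name makes clear which fallback `ALL_ATTENTION_FUNCTIONS.get_interface` receives when `config._attn_implementation` is `"eager"`. Below is a minimal, self-contained sketch of the sink mechanism; the function name, signature, and tensor shapes here are illustrative assumptions, not copied from the repository.

```python
import torch
import torch.nn.functional as F


def attention_with_sink(query, key, value, sinks, attention_mask=None, scaling=1.0):
    # query/key/value: (batch, num_heads, seq_len, head_dim); sinks: (num_heads,)
    attn_logits = torch.matmul(query, key.transpose(-2, -1)) * scaling
    if attention_mask is not None:
        attn_logits = attn_logits + attention_mask  # additive causal/padding mask

    # Append one learned "sink" logit per head so the softmax can place
    # probability mass on it instead of forcing it onto real tokens.
    bsz, num_heads, q_len, _ = attn_logits.shape
    sink_logits = sinks.view(1, num_heads, 1, 1).expand(bsz, num_heads, q_len, 1)
    combined = torch.cat([attn_logits, sink_logits], dim=-1)

    probs = F.softmax(combined, dim=-1)
    attn_weights = probs[..., :-1]  # drop the sink column before mixing values
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights
```

As the hunks above show, the renamed function is still passed as the eager fallback to `ALL_ATTENTION_FUNCTIONS.get_interface`, so non-eager implementations continue to be looked up by name while the sink-aware eager kernel remains the default.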