From 09320fc1f41a3ad2ee21e49c5913ca53449f78dd Mon Sep 17 00:00:00 2001 From: mtthw13 Date: Tue, 17 Feb 2026 14:11:31 +0800 Subject: [PATCH] Refactor GPT-Neo to use @capture_outputs and @can_return_tuple decorators --- .../models/gpt_neo/modeling_gpt_neo.py | 159 +++++------------- tests/models/gpt_neo/test_modeling_gpt_neo.py | 2 +- 2 files changed, 42 insertions(+), 119 deletions(-) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index e80b6d1c208a..d32c81280aed 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -26,8 +26,6 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithCrossAttentions, CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, @@ -36,8 +34,11 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, + can_return_tuple, logging, ) +from ...utils.generic import merge_with_config_defaults +from ...utils.output_capturing import capture_outputs from .configuration_gpt_neo import GPTNeoConfig @@ -135,7 +136,6 @@ def forward( attention_mask=None, layer_past=None, use_cache=False, - output_attentions=False, cache_position=None, ): query = self.q_proj(hidden_states) @@ -180,7 +180,6 @@ def forward( attention_mask=None, layer_past=None, use_cache=False, - output_attentions=False, cache_position=None, ): bsz, _, _ = hidden_states.size() @@ -282,7 +281,6 @@ def forward( layer_past=None, attention_mask=None, use_cache=False, - output_attentions=False, cache_position=None, ): return self.attention( @@ -290,7 +288,6 @@ def forward( attention_mask=attention_mask, layer_past=layer_past, use_cache=use_cache, - output_attentions=output_attentions, cache_position=cache_position, ) @@ -328,17 +325,15 @@ def forward( layer_past=None, attention_mask=None, use_cache=False, - output_attentions=False, cache_position=None, ): residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output, attn_weights = self.attn( + attn_output, _ = self.attn( hidden_states, layer_past=layer_past, attention_mask=attention_mask, use_cache=use_cache, - output_attentions=output_attentions, cache_position=cache_position, ) @@ -351,7 +346,7 @@ def forward( # residual connection hidden_states = residual + feed_forward_hidden_states - return hidden_states, attn_weights + return hidden_states @auto_docstring @@ -363,6 +358,10 @@ class GPTNeoPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn = True _can_compile_fullgraph = False # TODO: needs a hybrid cache + _can_record_outputs = { + "hidden_states": GPTNeoBlock, + "attentions": GPTNeoAttention, + } def _init_weights(self, module): super()._init_weights(module) @@ -398,6 +397,8 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.wte = new_embeddings + @merge_with_config_defaults + @capture_outputs @auto_docstring def forward( self, @@ -408,12 +409,9 @@ def forward( position_ids: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, cache_position: torch.LongTensor | None = None, **kwargs, - ) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions: + ) -> BaseModelOutputWithPast: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else @@ -428,23 +426,9 @@ def forward( [What are input IDs?](../glossary#input-ids) """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - if inputs_embeds is None: inputs_embeds = self.wte(input_ids) @@ -479,42 +463,21 @@ def forward( hidden_states = self.drop(hidden_states) output_shape = (-1, seq_length, hidden_states.size(-1)) - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, block in enumerate(self.h): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = block( + for block in self.h: + hidden_states = block( hidden_states, layer_past=past_key_values, attention_mask=causal_mask, use_cache=use_cache, - output_attentions=output_attentions, cache_position=cache_position, ) - hidden_states = outputs[0] - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[1],) - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None - ) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attentions, ) @@ -535,6 +498,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -546,13 +510,10 @@ def forward( inputs_embeds: torch.Tensor | None = None, labels: torch.Tensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, cache_position: torch.LongTensor | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs, - ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions: + ) -> CausalLMOutputWithPast: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else @@ -571,9 +532,7 @@ def forward( `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -581,14 +540,11 @@ def forward( position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, cache_position=cache_position, **kwargs, ) - hidden_states = transformer_outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) @@ -597,16 +553,12 @@ def forward( if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithPast( loss=loss, logits=logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) @@ -634,6 +586,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -645,11 +598,8 @@ def forward( inputs_embeds: torch.Tensor | None = None, labels: torch.Tensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, **kwargs, - ) -> tuple[torch.Tensor] | SequenceClassifierOutputWithPast: + ) -> SequenceClassifierOutputWithPast: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else @@ -668,9 +618,7 @@ def forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -678,11 +626,9 @@ def forward( position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) - hidden_states = transformer_outputs[0] + hidden_states = outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -730,16 +676,13 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) @@ -756,6 +699,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -767,11 +711,8 @@ def forward( inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, **kwargs, - ) -> tuple | TokenClassifierOutput: + ) -> TokenClassifierOutput: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else @@ -790,9 +731,7 @@ def forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -800,12 +739,10 @@ def forward( position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) - hidden_states = transformer_outputs[0] + hidden_states = outputs.last_hidden_state hidden_states = self.dropout(hidden_states) logits = self.classifier(hidden_states) @@ -815,15 +752,11 @@ def forward( loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - if not return_dict: - output = (logits,) + transformer_outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) @@ -838,6 +771,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -848,11 +782,8 @@ def forward( inputs_embeds: torch.FloatTensor | None = None, start_positions: torch.LongTensor | None = None, end_positions: torch.LongTensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, **kwargs, - ) -> tuple | QuestionAnsweringModelOutput: + ) -> QuestionAnsweringModelOutput: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else @@ -867,20 +798,16 @@ def forward( [What are input IDs?](../glossary#input-ids) """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -904,10 +831,6 @@ def forward( end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, diff --git a/tests/models/gpt_neo/test_modeling_gpt_neo.py b/tests/models/gpt_neo/test_modeling_gpt_neo.py index 3a45226c5d3b..8873e8aca1c1 100644 --- a/tests/models/gpt_neo/test_modeling_gpt_neo.py +++ b/tests/models/gpt_neo/test_modeling_gpt_neo.py @@ -458,7 +458,7 @@ def test_local_attn_probs(self): attention_mask = attention_mask[:, None, None, :] attention_mask = (1.0 - attention_mask) * -10000.0 - attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)[-1] + attn_probs = layer(hidden_states, attention_mask=attention_mask)[-1] # the last 2 tokens are masked, and should have 0 attn_probs self.assertTrue(torch.all(attn_probs[:, :, -mask_tokens:, -mask_tokens:] == 0))