diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 5bf0bc584d85..d25f5034bdcd 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -26,18 +26,21 @@
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
-    BaseModelOutputWithPastAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
     CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutputWithPast,
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
 from ...utils import (
+    TransformersKwargs,
     auto_docstring,
+    can_return_tuple,
     logging,
 )
+from ...utils.generic import merge_with_config_defaults
+from ...utils.output_capturing import capture_outputs
 from .configuration_gpt_neo import GPTNeoConfig
@@ -135,8 +138,7 @@ def forward(
         attention_mask=None,
         layer_past=None,
         use_cache=False,
-        output_attentions=False,
-        **kwargs,
+        cache_position=None,
     ):
         query = self.q_proj(hidden_states)
         key = self.k_proj(hidden_states)
@@ -147,7 +149,8 @@
         value = self._split_heads(value, self.num_heads, self.head_dim)
 
         if layer_past is not None:
-            key, value = layer_past.update(key, value, self.layer_id)
+            cache_kwargs = {"cache_position": cache_position}
+            key, value = layer_past.update(key, value, self.layer_id, cache_kwargs)
 
         attn_output, attn_weights = self._attn(query, key, value, attention_mask)
 
@@ -179,8 +182,7 @@ def forward(
         attention_mask=None,
         layer_past=None,
         use_cache=False,
-        output_attentions=False,
-        **kwargs,
+        cache_position=None,
     ):
         bsz, _, _ = hidden_states.size()
 
@@ -193,7 +195,8 @@
         value = self._split_heads(value, self.num_heads, self.head_dim)
 
         if layer_past is not None:
-            key, value = layer_past.update(key, value, self.layer_id)
+            cache_kwargs = {"cache_position": cache_position}
+            key, value = layer_past.update(key, value, self.layer_id, cache_kwargs)
 
         query_length = query.shape[2]
         tgt_len = key.shape[2]
@@ -280,15 +283,14 @@ def forward(
         layer_past=None,
         attention_mask=None,
         use_cache=False,
-        output_attentions=False,
-        **kwargs,
+        cache_position=None,
     ):
         return self.attention(
             hidden_states,
             attention_mask=attention_mask,
             layer_past=layer_past,
             use_cache=use_cache,
-            output_attentions=output_attentions,
+            cache_position=cache_position,
         )
 
 
@@ -325,17 +327,16 @@ def forward(
         layer_past=None,
         attention_mask=None,
         use_cache=False,
-        output_attentions=False,
-        **kwargs,
+        cache_position=None,
     ):
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
-        attn_output, attn_weights = self.attn(
+        attn_output, _ = self.attn(
             hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
             use_cache=use_cache,
-            output_attentions=output_attentions,
+            cache_position=cache_position,
         )
 
         # residual connection
@@ -347,7 +348,7 @@
         # residual connection
         hidden_states = residual + feed_forward_hidden_states
 
-        return hidden_states, attn_weights
+        return hidden_states
 
 
 @auto_docstring
@@ -359,6 +360,10 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn = True
     _can_compile_fullgraph = False  # TODO: needs a hybrid cache
+    _can_record_outputs = {
+        "hidden_states": GPTNeoBlock,
+        "attentions": GPTNeoAttention,
+    }
 
     def _init_weights(self, module):
         super()._init_weights(module)
@@ -394,6 +399,8 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, new_embeddings):
         self.wte = new_embeddings
 
+    @merge_with_config_defaults
+    @capture_outputs
     @auto_docstring
     def forward(
         self,
@@ -404,11 +411,9 @@
         position_ids: torch.Tensor | None = None,
         inputs_embeds: torch.Tensor | None = None,
         use_cache: bool | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
-    ) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
+        cache_position: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
             `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
@@ -423,13 +428,6 @@
 
             [What are input IDs?](../glossary#input-ids)
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.return_dict
-
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
@@ -446,15 +444,19 @@
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
-        if position_ids is None:
+        seq_length = inputs_embeds.shape[1]
+        if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
-            position_ids = position_ids.unsqueeze(0)
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
 
         causal_mask = create_causal_mask(
             config=self.config,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
+            cache_position=cache_position,
             past_key_values=past_key_values,
             position_ids=position_ids,
         )
@@ -462,7 +464,6 @@
         position_embeds = self.wpe(position_ids)
         hidden_states = inputs_embeds + position_embeds
 
-        seq_length = inputs_embeds.shape[1]
         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, seq_length)
             token_type_embeds = self.wte(token_type_ids)
@@ -471,41 +472,21 @@
         hidden_states = self.drop(hidden_states)
         output_shape = (-1, seq_length, hidden_states.size(-1))
 
-        all_self_attentions = () if output_attentions else None
-        all_hidden_states = () if output_hidden_states else None
-        for i, block in enumerate(self.h):
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            outputs = block(
+        for block in self.h:
+            hidden_states = block(
                 hidden_states,
                 layer_past=past_key_values,
                 attention_mask=causal_mask,
                 use_cache=use_cache,
-                output_attentions=output_attentions,
+                cache_position=cache_position,
             )
 
-            hidden_states = outputs[0]
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[1],)
-
         hidden_states = self.ln_f(hidden_states)
-
         hidden_states = hidden_states.view(output_shape)
-        # Add last hidden state
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(
-                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
-            )
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
         )
 
 
@@ -526,6 +507,7 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -537,12 +519,10 @@
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
         use_cache: bool | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
+        cache_position: torch.LongTensor | None = None,
         logits_to_keep: int | torch.Tensor = 0,
-        **kwargs,
-    ) -> tuple[torch.Tensor] | CausalLMOutputWithCrossAttentions:
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
             `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
@@ -561,9 +541,7 @@
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
-        return_dict = return_dict if return_dict is not None else self.config.return_dict
-
-        transformer_outputs = self.transformer(
+        transformer_outputs: BaseModelOutputWithPast = self.transformer(
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
@@ -571,13 +549,11 @@
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            cache_position=cache_position,
             **kwargs,
         )
 
-        hidden_states = transformer_outputs[0]
+        hidden_states = transformer_outputs.last_hidden_state
         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.lm_head(hidden_states[:, slice_indices, :])
@@ -586,10 +562,6 @@
         if labels is not None:
             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
 
-        if not return_dict:
-            output = (logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
@@ -623,6 +595,7 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -634,11 +607,8 @@
         inputs_embeds: torch.Tensor | None = None,
         labels: torch.Tensor | None = None,
         use_cache: bool | None = None,
-        output_attentions: bool | None = None,
-        output_hidden_states: bool | None = None,
-        return_dict: bool | None = None,
-        **kwargs,
-    ) -> tuple[torch.Tensor] | SequenceClassifierOutputWithPast:
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> SequenceClassifierOutputWithPast:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
             `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
@@ -657,9 +627,7 @@
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" - return_dict = return_dict if return_dict is not None else self.config.return_dict - - transformer_outputs = self.transformer( + transformer_outputs: BaseModelOutputWithPast = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -667,11 +635,9 @@ def forward( position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state logits = self.score(hidden_states) if input_ids is not None: @@ -719,9 +685,6 @@ def forward( elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, @@ -745,6 +708,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -756,11 +720,8 @@ def forward( inputs_embeds: torch.FloatTensor | None = None, labels: torch.LongTensor | None = None, use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - **kwargs, - ) -> tuple | TokenClassifierOutput: + **kwargs: Unpack[TransformersKwargs], + ) -> TokenClassifierOutput: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else @@ -779,9 +740,7 @@ def forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
""" - return_dict = return_dict if return_dict is not None else self.config.return_dict - - transformer_outputs = self.transformer( + transformer_outputs: BaseModelOutputWithPast = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -789,12 +748,10 @@ def forward( position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) - hidden_states = transformer_outputs[0] + hidden_states = transformer_outputs.last_hidden_state hidden_states = self.dropout(hidden_states) logits = self.classifier(hidden_states) @@ -804,10 +761,6 @@ def forward( loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - if not return_dict: - output = (logits,) + transformer_outputs[2:] - return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( loss=loss, logits=logits, @@ -827,6 +780,7 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + @can_return_tuple @auto_docstring def forward( self, @@ -837,11 +791,8 @@ def forward( inputs_embeds: torch.FloatTensor | None = None, start_positions: torch.LongTensor | None = None, end_positions: torch.LongTensor | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - **kwargs, - ) -> tuple | QuestionAnsweringModelOutput: + **kwargs: Unpack[TransformersKwargs], + ) -> QuestionAnsweringModelOutput: r""" input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): `input_ids_length` = `sequence_length` if `past_key_values` is `None` else @@ -856,20 +807,16 @@ def forward( [What are input IDs?](../glossary#input-ids) """ - return_dict = return_dict if return_dict is not None else self.config.return_dict - - outputs = self.transformer( + outputs: BaseModelOutputWithPast = self.transformer( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) - sequence_output = outputs[0] + sequence_output = outputs.last_hidden_state logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) @@ -893,10 +840,6 @@ def forward( end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((total_loss,) + output) if total_loss is not None else output - return QuestionAnsweringModelOutput( loss=total_loss, start_logits=start_logits, diff --git a/tests/models/gpt_neo/test_modeling_gpt_neo.py b/tests/models/gpt_neo/test_modeling_gpt_neo.py index 1adb5bf5083a..3924432f5837 100644 --- a/tests/models/gpt_neo/test_modeling_gpt_neo.py +++ b/tests/models/gpt_neo/test_modeling_gpt_neo.py @@ -457,7 +457,7 @@ def test_local_attn_probs(self): attention_mask = attention_mask[:, None, None, :] attention_mask = (1.0 - attention_mask) * -10000.0 - attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)[-1] + attn_probs = layer(hidden_states, attention_mask=attention_mask)[-1] # the last 2 tokens are masked, and should have 0 attn_probs self.assertTrue(torch.all(attn_probs[:, :, -mask_tokens:, -mask_tokens:] == 0))