From 7d4c48133a606dee4e7703466783df907ee19ec0 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 11 Oct 2024 15:22:56 +0000
Subject: [PATCH 1/5] sync gpus

---
 src/transformers/generation/utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 2225b033aa0a..a92a1660204c 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -379,10 +379,13 @@ def prepare_inputs_for_generation(
 
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
             if inputs_embeds is not None:  # Exception 1
                 input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif cache_position[-1] >= input_ids.shape[1]: # Exception 3
+                input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
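A note on the new Exception 3, since the comment is terse: under synced GPUs (for example DeepSpeed ZeRO-3), a peer that has already finished generating keeps running dummy forward passes so the collective ops on the other peers don't hang, and on such a peer cache_position can point one step past the end of input_ids. Below is a minimal sketch of the failure mode; the shapes are made up for illustration, and only the two slicing expressions come from the diff:

    import torch

    # Toy tensors, not transformers code: a finished peer doing a dummy
    # forward pass can have cache_position pointing past the sequence end.
    input_ids = torch.arange(8).unsqueeze(0)  # (1, 8): 8 tokens so far
    cache_position = torch.tensor([8])        # next position: index 8

    # The default slicing path would raise, since index 8 is out of bounds
    # for a dimension of size 8:
    #     input_ids[:, cache_position]  # IndexError
    # Exception 3 instead takes the last cache_position.shape[0] tokens as a
    # dummy input, which is all a dummy forward pass needs:
    dummy_input_ids = input_ids[:, -cache_position.shape[0] :]
    print(dummy_input_ids.shape)  # torch.Size([1, 1])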
From 0ee8da43d161643867f299c8f671aba243634437 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 11 Oct 2024 15:29:46 +0000
Subject: [PATCH 2/5] sync gpus

---
 src/transformers/generation/utils.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index a92a1660204c..dddbef7d6c4b 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -3171,8 +3171,14 @@ def _sample(
             # forward pass to get next token
             outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3217,11 +3223,6 @@ def _sample(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
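Why the update has to move above the skip: with synced_gpus, a finished peer stays in the loop (and keeps calling forward) purely to keep the collectives aligned, but the old placement meant continue also skipped the model_kwargs bookkeeping, so per-peer state such as cache_position stopped advancing. A toy stand-in for the reordered loop; update_kwargs below is a hypothetical simplification, not the real _update_model_kwargs_for_generation:

    import torch

    def update_kwargs(kwargs):  # hypothetical stand-in for the real helper
        kwargs["cache_position"] = kwargs["cache_position"] + 1
        return kwargs

    model_kwargs = {"cache_position": torch.tensor([7])}
    synced_gpus, this_peer_finished = True, True

    for _ in range(3):  # three dummy iterations on an already-finished peer
        # ... the forward pass runs here on every peer ...
        model_kwargs = update_kwargs(model_kwargs)  # bookkeeping first
        if synced_gpus and this_peer_finished:
            continue  # skip only the sampling and stopping logic
        # ... token selection would run here on active peers ...

    print(model_kwargs["cache_position"])  # tensor([10]): kept pace with the active peers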
From 287911fcc299428a9bc1b31cde799f45d0af4c70 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 11 Oct 2024 15:37:16 +0000
Subject: [PATCH 3/5] fix other decoding methods

---
 src/transformers/generation/utils.py | 85 ++++++++++++++--------------
 1 file changed, 44 insertions(+), 41 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index dddbef7d6c4b..905a6a62dfe6 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -384,7 +384,7 @@ def prepare_inputs_for_generation(
             model_inputs["past_key_values"] = past_key_values
             if inputs_embeds is not None:  # Exception 1
                 input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif cache_position[-1] >= input_ids.shape[1]: # Exception 3
+            elif cache_position[-1] >= input_ids.shape[1]:  # Exception 3
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
@@ -2612,8 +2612,14 @@ def _dola_decoding(
                     outputs.hidden_states[candidate_premature_layer][:, -1, :]
                 ).to(final_logits.device)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             next_token_logits = _dola_select_contrast(
                 candidate_premature_layers, candidate_premature_logits, final_logits
@@ -2655,11 +2661,6 @@ def _dola_decoding(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3019,8 +3020,14 @@ def _contrastive_search(
                 )
             # contrastive_search main logic end
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
@@ -3030,11 +3037,6 @@ def _contrastive_search(
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3419,9 +3421,15 @@ def _beam_search(
             else:  # Unchanged original behavior
                 outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3495,12 +3503,6 @@ def _beam_search(
 
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
 
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
             # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3674,9 +3676,15 @@ def _group_beam_search(
 
             outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue
 
             if output_scores:
                 processed_score = torch.zeros_like(outputs.logits[:, -1, :])
@@ -3786,12 +3794,6 @@ def _group_beam_search(
 
             input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
 
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
             # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3952,9 +3954,15 @@ def _constrained_beam_search(
 
             outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -4022,11 +4030,6 @@ def _constrained_beam_search(
             beam_idx = beam_outputs["next_beam_indices"]
 
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
@@ -4275,8 +4278,15 @@ def _assisted_decoding(
             # 5. Update the candidate generation strategy if needed
             candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                num_new_tokens=n_matches + 1,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Store scores, attentions and hidden_states when required
             # Assistant: modified to append one tuple element per token, as in the other generation methods.
@@ -4323,13 +4333,6 @@ def _assisted_decoding(
                         decoder_hidden_states, outputs.hidden_states, cur_len, added_len
                     )
 
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                num_new_tokens=n_matches + 1,
-            )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
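One detail specific to the beam variants above: their skip branch also advances cur_len by hand, because a skipping peer never reaches the torch.cat that grows input_ids on the active peers. A plain-Python sketch of that bookkeeping, with toy numbers rather than the real method:

    cur_len = 10  # pretend sequence length when this peer finished
    synced_gpus, this_peer_finished = True, True

    for _ in range(4):  # four more iterations driven by the active peers
        # ... forward pass and model_kwargs update happen on every peer ...
        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1  # mirror the token the active peers append
            continue
        # active peers: input_ids = torch.cat(...), and cur_len tracks its width

    print(cur_len)  # 14, matching the active peers' sequence length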
From 53d8e10f71846f8a3f6bf82180f4e42732f42c8d Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 11 Oct 2024 15:48:59 +0000
Subject: [PATCH 4/5] nit

---
 src/transformers/generation/utils.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 905a6a62dfe6..d1419af0896c 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -382,9 +382,7 @@ def prepare_inputs_for_generation(
         # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None:  # Exception 1
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif cache_position[-1] >= input_ids.shape[1]:  # Exception 3
+            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
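The nit is behavior-preserving: Exception 1 and Exception 3 ran the exact same slice, so the two branches collapse into a single condition. A quick check with made-up tensors that the merged branch still selects the dummy slice in the Exception 3 case:

    import torch

    input_ids = torch.arange(6).unsqueeze(0)  # (1, 6), toy values
    cache_position = torch.tensor([6])        # out of bounds: the Exception 3 case
    inputs_embeds = None

    if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:
        sliced = input_ids[:, -cache_position.shape[0] :]  # shared dummy slice
    elif input_ids.shape[1] != cache_position.shape[0]:
        sliced = input_ids[:, cache_position]
    print(sliced)  # tensor([[5]]), same result the two old branches produced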
From 262c9719aea8886b58923325ead7058b860e0f9c Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 11 Oct 2024 16:29:41 +0000
Subject: [PATCH 5/5] fix assisted gen (consistent return api)

---
 src/transformers/generation/utils.py | 35 +++++++++-------------------
 1 file changed, 11 insertions(+), 24 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index d1419af0896c..68b8b598ec09 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4167,17 +4167,8 @@ def _assisted_decoding(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
-        # This is needed if return_dict_in_generate is True
-        start_from_empty_dynamic_cache = False
-        past_key_values = model_kwargs.get("past_key_values", None)
-        if isinstance(past_key_values, DynamicCache) or (
-            isinstance(past_key_values, EncoderDecoderCache)
-            and isinstance(past_key_values.self_attention_cache, DynamicCache)
-        ):
-            if past_key_values.get_seq_length() == 0:
-                start_from_empty_dynamic_cache = True
-
         this_peer_finished = False
+        is_first_iteration = True  # to preserve the same API in the output as other generation methods
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
 
@@ -4289,52 +4280,48 @@ def _assisted_decoding(
 
             # Store scores, attentions and hidden_states when required
             # Assistant: modified to append one tuple element per token, as in the other generation methods.
             if return_dict_in_generate:
+                newly_added_length = n_matches + 1
                 if output_scores:
-                    scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
+                    scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
                 if output_logits:
-                    raw_logits += (next_token_logits,)
-
-                if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache:
-                    added_len = new_cur_len
-                    # set it to false for other iterations
-                    start_from_empty_dynamic_cache = False
-                else:
-                    added_len = n_matches + 1
+                    raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))
+                newly_added_length = new_cur_len if is_first_iteration else newly_added_length
 
                 if output_attentions:
                     if self.config.is_encoder_decoder:
                         cross_attentions = _split_model_outputs(
-                            cross_attentions, outputs.cross_attentions, cur_len, added_len
+                            cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
                         )
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.decoder_attentions,
                             cur_len,
-                            added_len,
+                            newly_added_length,
                             is_decoder_attention=True,
                         )
                     else:
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.attentions,
                             cur_len,
-                            added_len,
+                            newly_added_length,
                             is_decoder_attention=True,
                         )
 
                 if output_hidden_states:
                     if self.config.is_encoder_decoder:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
+                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
                         )
                     else:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.hidden_states, cur_len, added_len
+                            decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
                         )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
+            is_first_iteration = False
 
         if streamer is not None:
             streamer.end()
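What "consistent return api" buys: assisted generation can accept several tokens per iteration, and after this patch the returned tuples hold one entry per accepted token (as _sample already does) rather than, for raw_logits, one entry per iteration. A toy version of the per-token append, with invented shapes:

    import torch

    new_logits = torch.randn(1, 3, 10)  # batch=1, n_matches + 1 = 3 tokens, vocab=10
    newly_added_length = new_logits.shape[1]

    scores = ()
    scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))

    # One (batch, vocab) entry per accepted token, as in the other methods:
    print(len(scores), scores[0].shape)  # 3 torch.Size([1, 10])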