From 4886ee333e3ca3f15d090bfedae0458a75562c1e Mon Sep 17 00:00:00 2001 From: ojh31 Date: Mon, 12 Aug 2024 16:17:08 -0700 Subject: [PATCH 1/3] Tracked finish_position in _sample() --- src/transformers/generation/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 998288bd38df..ef3b5f08d14b 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2968,6 +2968,7 @@ def _sample( unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences( this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length ): @@ -2982,7 +2983,7 @@ def _sample( outputs = self(**model_inputs, return_dict=True) if synced_gpus and this_peer_finished: - continue # don't waste resources running the code we don't need + finish_position = finish_position or input_ids.shape[-1] # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) @@ -3046,6 +3047,9 @@ def _sample( if streamer is not None: streamer.end() + if finish_position is not None: + input_ids = input_ids[:, : finish_position] + if return_dict_in_generate: if self.config.is_encoder_decoder: return GenerateEncoderDecoderOutput( From aaa07db9f4bad47f1103ccdbd0a9f92aff78521f Mon Sep 17 00:00:00 2001 From: ojh31 Date: Mon, 12 Aug 2024 16:32:25 -0700 Subject: [PATCH 2/3] Updated other methods using this_peer_finished --- src/transformers/generation/utils.py | 42 ++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index ef3b5f08d14b..9e9cb5bf43e3 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2408,6 +2408,7 @@ def _dola_decoding( if lm_head is None: raise ValueError("DoLa is not supported for models that don't have output embeddings.") + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -2429,7 +2430,7 @@ def _dola_decoding( ) if synced_gpus and this_peer_finished: - continue # don't waste resources running the code we don't need + finish_position = finish_position or input_ids.shape[-1] next_token_logits = _dola_select_contrast( candidate_premature_layers, candidate_premature_logits, final_logits @@ -2485,6 +2486,9 @@ def _dola_decoding( if streamer is not None: streamer.end() + if finish_position is not None: + input_ids = input_ids[:, : finish_position] + if return_dict_in_generate: return GenerateDecoderOnlyOutput( sequences=input_ids, @@ -2571,7 +2575,7 @@ def _contrastive_search( model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) this_peer_finished = False - + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step @@ -2820,7 +2824,7 @@ def _contrastive_search( # contrastive_search main logic end if synced_gpus and this_peer_finished: - continue # don't waste resources running the code we don't need + finish_position = finish_position or input_ids.shape[-1] # finished sentences should have their next token be a padding token if has_eos_stopping_criteria: @@ -2843,6 +2847,9 @@ def _contrastive_search( if streamer is not None: streamer.end() + if finish_position is not None: + input_ids = input_ids[:, : finish_position] + if return_dict_in_generate: # Contrastive search works by forward looking at the next token, so we need to exclude it from # `past_key_values` to be consistent with the other decoding methods @@ -3201,7 +3208,7 @@ def _beam_search( this_peer_finished = False decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -3242,8 +3249,7 @@ def _beam_search( outputs = self(**model_inputs, return_dict=True) if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need + finish_position = finish_position or input_ids.shape[-1] # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) @@ -3356,6 +3362,9 @@ def _beam_search( decoder_prompt_len=decoder_prompt_len, ) + if finish_position is not None: + input_ids = input_ids[:, : finish_position] + if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None @@ -3480,6 +3489,7 @@ def _group_beam_search( this_peer_finished = False decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # predicted tokens in cur_len step current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) @@ -3497,8 +3507,7 @@ def _group_beam_search( outputs = self(**model_inputs, return_dict=True) if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need + finish_position = finish_position or input_ids.shape[-1] if output_scores: processed_score = torch.zeros_like(outputs.logits[:, -1, :]) @@ -3643,6 +3652,9 @@ def _group_beam_search( decoder_prompt_len=decoder_prompt_len, ) + if finish_position is not None: + input_ids = input_ids[:, : finish_position] + if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None @@ -3767,6 +3779,7 @@ def _constrained_beam_search( this_peer_finished = False decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -3777,8 +3790,7 @@ def _constrained_beam_search( outputs = self(**model_inputs, return_dict=True) if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need + finish_position = finish_position or input_ids.shape[-1] # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration # (the clone itself is always small) @@ -3883,6 +3895,9 @@ def _constrained_beam_search( decoder_prompt_len=decoder_prompt_len, ) + if finish_position is not None: + input_ids = input_ids[:, : finish_position] + if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None @@ -4004,6 +4019,7 @@ def _assisted_decoding( start_from_empty_dynamic_cache = True this_peer_finished = False + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): cur_len = input_ids.shape[-1] @@ -4104,7 +4120,7 @@ def _assisted_decoding( candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches) if synced_gpus and this_peer_finished: - continue # don't waste resources running the code we don't need + finish_position = finish_position or input_ids.shape[-1] # Store scores, attentions and hidden_states when required # Assistant: modified to append one tuple element per token, as in the other generation methods. @@ -4171,6 +4187,10 @@ def _assisted_decoding( candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( candidate_generator.num_assistant_tokens ) + + if finish_position is not None: + input_ids = input_ids[:, : finish_position] + if return_dict_in_generate: if self.config.is_encoder_decoder: return GenerateEncoderDecoderOutput( From 112c1005f862dde31cb6742f6e7b53fdd8ef2d2c Mon Sep 17 00:00:00 2001 From: ojh31 Date: Mon, 12 Aug 2024 16:52:15 -0700 Subject: [PATCH 3/3] Fixed some whitespace issues --- src/transformers/generation/utils.py | 32 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9e9cb5bf43e3..4aa9be732c94 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2408,7 +2408,7 @@ def _dola_decoding( if lm_head is None: raise ValueError("DoLa is not supported for models that don't have output embeddings.") - finish_position = None # the position where the sequence is finished (for multi-gpu) + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -2487,7 +2487,7 @@ def _dola_decoding( streamer.end() if finish_position is not None: - input_ids = input_ids[:, : finish_position] + input_ids = input_ids[:, :finish_position] if return_dict_in_generate: return GenerateDecoderOnlyOutput( @@ -2575,7 +2575,7 @@ def _contrastive_search( model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) this_peer_finished = False - finish_position = None # the position where the sequence is finished (for multi-gpu) + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step @@ -2848,7 +2848,7 @@ def _contrastive_search( streamer.end() if finish_position is not None: - input_ids = input_ids[:, : finish_position] + input_ids = input_ids[:, :finish_position] if return_dict_in_generate: # Contrastive search works by forward looking at the next token, so we need to exclude it from @@ -2975,7 +2975,7 @@ def _sample( unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - finish_position = None # the position where the sequence is finished (for multi-gpu) + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences( this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length ): @@ -3055,7 +3055,7 @@ def _sample( streamer.end() if finish_position is not None: - input_ids = input_ids[:, : finish_position] + input_ids = input_ids[:, :finish_position] if return_dict_in_generate: if self.config.is_encoder_decoder: @@ -3208,7 +3208,7 @@ def _beam_search( this_peer_finished = False decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - finish_position = None # the position where the sequence is finished (for multi-gpu) + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -3363,7 +3363,7 @@ def _beam_search( ) if finish_position is not None: - input_ids = input_ids[:, : finish_position] + input_ids = input_ids[:, :finish_position] if return_dict_in_generate: if not output_scores: @@ -3489,7 +3489,7 @@ def _group_beam_search( this_peer_finished = False decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - finish_position = None # the position where the sequence is finished (for multi-gpu) + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # predicted tokens in cur_len step current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) @@ -3653,8 +3653,8 @@ def _group_beam_search( ) if finish_position is not None: - input_ids = input_ids[:, : finish_position] - + input_ids = input_ids[:, :finish_position] + if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None @@ -3779,7 +3779,7 @@ def _constrained_beam_search( this_peer_finished = False decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder - finish_position = None # the position where the sequence is finished (for multi-gpu) + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -3896,7 +3896,7 @@ def _constrained_beam_search( ) if finish_position is not None: - input_ids = input_ids[:, : finish_position] + input_ids = input_ids[:, :finish_position] if return_dict_in_generate: if not output_scores: @@ -4019,7 +4019,7 @@ def _assisted_decoding( start_from_empty_dynamic_cache = True this_peer_finished = False - finish_position = None # the position where the sequence is finished (for multi-gpu) + finish_position = None # the position where the sequence is finished (for multi-gpu) while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): cur_len = input_ids.shape[-1] @@ -4187,9 +4187,9 @@ def _assisted_decoding( candidate_generator.assistant_model.generation_config.num_assistant_tokens = ( candidate_generator.num_assistant_tokens ) - + if finish_position is not None: - input_ids = input_ids[:, : finish_position] + input_ids = input_ids[:, :finish_position] if return_dict_in_generate: if self.config.is_encoder_decoder: