diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 998288bd38df..4aa9be732c94 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2408,6 +2408,7 @@ def _dola_decoding(
         if lm_head is None:
             raise ValueError("DoLa is not supported for models that don't have output embeddings.")
 
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -2429,7 +2430,7 @@ def _dola_decoding(
             )
 
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             next_token_logits = _dola_select_contrast(
                 candidate_premature_layers, candidate_premature_logits, final_logits
@@ -2485,6 +2486,9 @@ def _dola_decoding(
         if streamer is not None:
             streamer.end()
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             return GenerateDecoderOnlyOutput(
                 sequences=input_ids,
@@ -2571,7 +2575,7 @@ def _contrastive_search(
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
         this_peer_finished = False
-
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
             # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
@@ -2820,7 +2824,7 @@ def _contrastive_search(
             # contrastive_search main logic end
 
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
@@ -2843,6 +2847,9 @@ def _contrastive_search(
         if streamer is not None:
             streamer.end()
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             # Contrastive search works by forward looking at the next token, so we need to exclude it from
             # `past_key_values` to be consistent with the other decoding methods
@@ -2968,6 +2975,7 @@ def _sample(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(
             this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
         ):
@@ -2982,7 +2990,7 @@ def _sample(
             outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3046,6 +3054,9 @@ def _sample(
         if streamer is not None:
             streamer.end()
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return GenerateEncoderDecoderOutput(
@@ -3197,7 +3208,7 @@ def _beam_search(
         this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -3238,8 +3249,7 @@ def _beam_search(
             outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3352,6 +3362,9 @@ def _beam_search(
             decoder_prompt_len=decoder_prompt_len,
         )
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             if not output_scores:
                 sequence_outputs["sequence_scores"] = None
@@ -3476,6 +3489,7 @@ def _group_beam_search(
         this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # predicted tokens in cur_len step
             current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
@@ -3493,8 +3507,7 @@ def _group_beam_search(
             outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             if output_scores:
                 processed_score = torch.zeros_like(outputs.logits[:, -1, :])
@@ -3639,6 +3652,9 @@ def _group_beam_search(
             decoder_prompt_len=decoder_prompt_len,
         )
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             if not output_scores:
                 sequence_outputs["sequence_scores"] = None
@@ -3763,6 +3779,7 @@ def _constrained_beam_search(
         this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -3773,8 +3790,7 @@ def _constrained_beam_search(
             outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3879,6 +3895,9 @@ def _constrained_beam_search(
             decoder_prompt_len=decoder_prompt_len,
         )
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             if not output_scores:
                 sequence_outputs["sequence_scores"] = None
@@ -4000,6 +4019,7 @@ def _assisted_decoding(
                 start_from_empty_dynamic_cache = True
 
         this_peer_finished = False
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
 
@@ -4100,7 +4120,7 @@ def _assisted_decoding(
             candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
 
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             # Store scores, attentions and hidden_states when required
             # Assistant: modified to append one tuple element per token, as in the other generation methods.
@@ -4167,6 +4187,10 @@ def _assisted_decoding(
             candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
                 candidate_generator.num_assistant_tokens
            )
+
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return GenerateEncoderDecoderOutput(
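
For reference, here is a minimal standalone sketch of the pattern these hunks introduce. It is not the `transformers` implementation; `toy_decode_loop`, its stand-in sampling, and its toy stopping rule are hypothetical. The idea is the same as in the diff: under `synced_gpus`, a peer whose sequences are already finished keeps stepping the loop so collective ops stay aligned across ranks, records in `finish_position` the length at which it actually finished, and slices off the throwaway tokens once the loop exits.

```python
import torch


def toy_decode_loop(input_ids: torch.Tensor, max_new_tokens: int, synced_gpus: bool = True) -> torch.Tensor:
    this_peer_finished = False
    finish_position = None  # the position where the sequence is finished (for multi-gpu)

    for step in range(max_new_tokens):
        # A real implementation runs the model forward pass here on every rank,
        # even after this peer has finished, so distributed collectives stay in sync.
        next_tokens = torch.randint(1, 100, (input_ids.shape[0],))  # stand-in for sampling

        if synced_gpus and this_peer_finished:
            # Instead of `continue`, remember where generation really ended;
            # the loop keeps appending dummy tokens to stay in lockstep with other ranks.
            finish_position = finish_position or input_ids.shape[-1]

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        # Toy stopping criterion: pretend every sequence hits EOS after 3 new tokens.
        if step >= 2:
            this_peer_finished = True

    # Drop the dummy tokens appended after the sequences finished.
    if finish_position is not None:
        input_ids = input_ids[:, :finish_position]
    return input_ids


prompt = torch.zeros((2, 4), dtype=torch.long)
out = toy_decode_loop(prompt, max_new_tokens=8)
print(out.shape)  # torch.Size([2, 7]): 4 prompt tokens + 3 generated, extras trimmed
```

Note that the `finish_position = finish_position or input_ids.shape[-1]` idiom only works because a valid finish position is never zero (there is always at least a prompt token), so the first recorded value is kept on later iterations.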