48 changes: 36 additions & 12 deletions src/transformers/generation/utils.py
@@ -2408,6 +2408,7 @@ def _dola_decoding(
         if lm_head is None:
             raise ValueError("DoLa is not supported for models that don't have output embeddings.")
 
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -2429,7 +2430,7 @@ def _dola_decoding(
             )
 
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             next_token_logits = _dola_select_contrast(
                 candidate_premature_layers, candidate_premature_logits, final_logits
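
Note on the replacement line (illustration, not part of the diff): `finish_position = finish_position or input_ids.shape[-1]` records the sequence length only on the first iteration where this peer is finished. Once set, the value is a length >= 1 and therefore truthy, so the usual falsy-zero pitfall of `or` cannot trigger here. A minimal demonstration:

finish_position = None
for cur_len in (7, 8, 9):  # the loop keeps running after this peer finished
    finish_position = finish_position or cur_len
print(finish_position)  # 7 -- only the first finished length is kept
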
@@ -2485,6 +2486,9 @@ def _dola_decoding(
         if streamer is not None:
             streamer.end()
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             return GenerateDecoderOnlyOutput(
                 sequences=input_ids,
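
Taken together, the three hunks above show the pattern this PR applies to every decoding loop: under `synced_gpus`, a peer whose sequences were finished used to `continue` past the token-selection code, so its `input_ids` stopped growing while the other peers kept stepping; now every peer runs the full loop body in lockstep, records where its own output actually finished, and truncates the filler tail after the loop. A minimal sketch of the idea, where `all_peers_done`, `forward_step`, `pick_token`, and `is_finished` are hypothetical stand-ins for the distributed finished-check, the model forward pass, token selection, and the per-peer stopping criteria:

def lockstep_generate(input_ids, all_peers_done, forward_step, pick_token, is_finished):
    # Sketch only: `input_ids` is a plain list of token ids for one peer.
    finish_position = None
    while not all_peers_done():
        logits = forward_step(input_ids)  # every peer joins every forward pass
        if is_finished(input_ids):
            # Record the finished length once; keep appending filler tokens so
            # all peers execute identical per-step work.
            finish_position = finish_position or len(input_ids)
        input_ids = input_ids + [pick_token(logits)]
    if finish_position is not None:
        input_ids = input_ids[:finish_position]  # drop the filler tail
    return input_ids
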
@@ -2571,7 +2575,7 @@ def _contrastive_search(
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
         this_peer_finished = False
-
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
             # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
@@ -2820,7 +2824,7 @@ def _contrastive_search(
             # contrastive_search main logic end
 
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
@@ -2843,6 +2847,9 @@ def _contrastive_search(
         if streamer is not None:
             streamer.end()
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             # Contrastive search works by forward looking at the next token, so we need to exclude it from
             # `past_key_values` to be consistent with the other decoding methods
@@ -2968,6 +2975,7 @@ def _sample(
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(
             this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
         ):
@@ -2982,7 +2990,7 @@ def _sample(
             outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3046,6 +3054,9 @@ def _sample(
         if streamer is not None:
             streamer.end()
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return GenerateEncoderDecoderOutput(
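
For orientation (not part of the diff): `synced_gpus` is the `generate()` flag that makes `_has_unfinished_sequences` poll all ranks and keep looping until every rank is done; it matters when the forward pass itself involves collective communication, as with DeepSpeed ZeRO-3 or FSDP sharding. A hedged usage sketch: `model` and `tokenizer` are assumed to be an already-sharded causal LM and its tokenizer, and `prompt_for_this_rank` is a hypothetical per-rank prompt, so different ranks finish at different steps.

inputs = tokenizer(prompt_for_this_rank, return_tensors="pt").to(model.device)
output_ids = model.generate(
    **inputs,
    max_new_tokens=64,
    synced_gpus=True,  # keep ranks in lockstep so sharded forward passes don't hang
)
# With this change, a rank that finished early gets its output truncated at its
# own finish position instead of carrying the filler generated while waiting.
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
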
@@ -3197,7 +3208,7 @@ def _beam_search(
         this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -3238,8 +3249,7 @@ def _beam_search(
             outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3352,6 +3362,9 @@ def _beam_search(
             decoder_prompt_len=decoder_prompt_len,
         )
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             if not output_scores:
                 sequence_outputs["sequence_scores"] = None
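
The post-loop cleanup is plain slicing along the sequence dimension; a toy, self-contained illustration:

import torch

# Two sequences that were finished at length 5 but kept stepping to length 8
# so the other ranks could continue.
input_ids = torch.arange(16).reshape(2, 8)
finish_position = 5  # recorded on the first iteration after finishing
input_ids = input_ids[:, :finish_position]
print(input_ids.shape)  # torch.Size([2, 5])
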
@@ -3476,6 +3489,7 @@ def _group_beam_search(
         this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # predicted tokens in cur_len step
             current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
@@ -3493,8 +3507,7 @@ def _group_beam_search(
             outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             if output_scores:
                 processed_score = torch.zeros_like(outputs.logits[:, -1, :])
@@ -3639,6 +3652,9 @@ def _group_beam_search(
             decoder_prompt_len=decoder_prompt_len,
         )
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             if not output_scores:
                 sequence_outputs["sequence_scores"] = None
@@ -3763,6 +3779,7 @@ def _constrained_beam_search(
         this_peer_finished = False
 
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -3773,8 +3790,7 @@ def _constrained_beam_search(
             outputs = self(**model_inputs, return_dict=True)
 
             if synced_gpus and this_peer_finished:
-                cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3879,6 +3895,9 @@ def _constrained_beam_search(
             decoder_prompt_len=decoder_prompt_len,
         )
 
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             if not output_scores:
                 sequence_outputs["sequence_scores"] = None
@@ -4000,6 +4019,7 @@ def _assisted_decoding(
         start_from_empty_dynamic_cache = True
 
         this_peer_finished = False
+        finish_position = None  # the position where the sequence is finished (for multi-gpu)
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
 
@@ -4100,7 +4120,7 @@ def _assisted_decoding(
             candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
 
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                finish_position = finish_position or input_ids.shape[-1]
 
             # Store scores, attentions and hidden_states when required
             # Assistant: modified to append one tuple element per token, as in the other generation methods.
@@ -4167,6 +4187,10 @@ def _assisted_decoding(
             candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
                 candidate_generator.num_assistant_tokens
             )
+
+        if finish_position is not None:
+            input_ids = input_ids[:, :finish_position]
+
         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
                 return GenerateEncoderDecoderOutput(