Merged
134 changes: 63 additions & 71 deletions src/transformers/generation/utils.py
@@ -379,9 +379,10 @@ def prepare_inputs_for_generation(
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
if past_key_values is not None:
model_inputs["past_key_values"] = past_key_values
if inputs_embeds is not None: # Exception 1
if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 or Exception 3
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
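
The new Exception 3 covers the synced-GPUs case: a peer whose sequences are already finished keeps issuing dummy forward passes, so cache_position can point past the end of input_ids; the fallback slice then keeps only the trailing dummy token. A minimal, self-contained sketch of that slicing rule (the helper name is hypothetical, not part of the library):

import torch

def slice_unprocessed_tokens(input_ids, cache_position, inputs_embeds=None):
    # Sketch of the rule above, assuming input_ids of shape (batch, seq_len)
    # and a 1-D cache_position tensor.
    seq_len = input_ids.shape[1]
    if inputs_embeds is not None or cache_position[-1] >= seq_len:
        # Exception 1 or Exception 3: keep only as many trailing tokens as there
        # are cache positions (a single dummy token in the synced-GPUs case)
        return input_ids[:, -cache_position.shape[0] :]
    elif seq_len != cache_position.shape[0]:
        # Default case: index the still-unprocessed positions directly
        return input_ids[:, cache_position]
    return input_ids  # Exception 2: the generation method already sliced

# A finished peer whose cache_position points one step past the prompt only
# feeds the last (dummy) token:
print(slice_unprocessed_tokens(torch.tensor([[5, 6, 7, 8]]), torch.tensor([4])))  # tensor([[8]])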
@@ -2609,8 +2610,14 @@ def _dola_decoding(
outputs.hidden_states[candidate_premature_layer][:, -1, :]
).to(final_logits.device)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
continue

next_token_logits = _dola_select_contrast(
candidate_premature_layers, candidate_premature_logits, final_logits
@@ -2652,11 +2659,6 @@ def _dola_decoding(
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# stop when each sentence is finished
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
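
This relocation is the core of the change and is repeated in every decoding method below: _update_model_kwargs_for_generation now runs before the synced_gpus early continue, so a peer whose sequences are finished still advances its cache position, attention mask, and past key values in lockstep with the peers that keep generating. A schematic sketch of the resulting loop shape (illustrative names only, not the transformers implementation):

def decoding_loop(model, model_kwargs, update_kwargs, has_unfinished, synced_gpus):
    this_peer_finished = False
    while has_unfinished(this_peer_finished, synced_gpus):
        outputs = model(**model_kwargs)

        # Cache position / attention mask / past_key_values advance for every peer,
        # finished or not, so all ranks stay in lockstep.
        model_kwargs = update_kwargs(outputs, model_kwargs)

        # A finished peer only skips the token selection, not the kwargs update.
        if synced_gpus and this_peer_finished:
            continue

        # ... select next tokens, append to input_ids, run stopping criteria ...
        this_peer_finished = True  # placeholder for the real stopping condition
    return model_kwargs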
@@ -3016,8 +3018,14 @@ def _contrastive_search(
)
# contrastive_search main logic end

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
continue

# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
@@ -3027,11 +3035,6 @@ def _contrastive_search(
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# stop when each sentence is finished
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3168,8 +3171,14 @@ def _sample(
# forward pass to get next token
outputs = self(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
continue

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
@@ -3214,11 +3223,6 @@ def _sample(
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
@@ -3415,9 +3419,15 @@ def _beam_search(
else: # Unchanged original behavior
outputs = self(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
continue

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
@@ -3491,12 +3501,6 @@ def _beam_search(

input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3670,9 +3674,15 @@ def _group_beam_search(

outputs = self(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
continue

if output_scores:
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
@@ -3782,12 +3792,6 @@ def _group_beam_search(

input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3948,9 +3952,15 @@ def _constrained_beam_search(

outputs = self(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
continue

# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
@@ -4018,11 +4028,6 @@ def _constrained_beam_search(
beam_idx = beam_outputs["next_beam_indices"]

input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)

# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
@@ -4162,17 +4167,8 @@ def _assisted_decoding(
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

# This is needed if return_dict_in_generate is True
start_from_empty_dynamic_cache = False
past_key_values = model_kwargs.get("past_key_values", None)
if isinstance(past_key_values, DynamicCache) or (
isinstance(past_key_values, EncoderDecoderCache)
and isinstance(past_key_values.self_attention_cache, DynamicCache)
):
if past_key_values.get_seq_length() == 0:
start_from_empty_dynamic_cache = True

Comment on lines -4165 to -4174
@gante (Contributor, Author), Oct 11, 2024:
Simplifies logic in assisted generation: see the new is_first_iteration variable and its uses :)
this_peer_finished = False
is_first_iteration = True # to preserve the same API in the output as other generation methods
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
cur_len = input_ids.shape[-1]

@@ -4271,63 +4267,59 @@ def _assisted_decoding(
# 5. Update the candidate generation strategy if needed
candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
num_new_tokens=n_matches + 1,
)
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
continue

# Store scores, attentions and hidden_states when required
# Assistant: modified to append one tuple element per token, as in the other generation methods.
if return_dict_in_generate:
newly_added_length = n_matches + 1
if output_scores:
scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
if output_logits:
raw_logits += (next_token_logits,)

if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache:
added_len = new_cur_len
# set it to false for other iterations
start_from_empty_dynamic_cache = False
else:
added_len = n_matches + 1
raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))

newly_added_length = new_cur_len if is_first_iteration else newly_added_length
if output_attentions:
if self.config.is_encoder_decoder:
cross_attentions = _split_model_outputs(
cross_attentions, outputs.cross_attentions, cur_len, added_len
cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
)
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.decoder_attentions,
cur_len,
added_len,
newly_added_length,
is_decoder_attention=True,
)
else:
decoder_attentions = _split_model_outputs(
decoder_attentions,
outputs.attentions,
cur_len,
added_len,
newly_added_length,
is_decoder_attention=True,
)
if output_hidden_states:
if self.config.is_encoder_decoder:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
)
else:
decoder_hidden_states = _split_model_outputs(
decoder_hidden_states, outputs.hidden_states, cur_len, added_len
decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
)

model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
num_new_tokens=n_matches + 1,
)

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
is_first_iteration = False

if streamer is not None:
streamer.end()
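
The removed start_from_empty_dynamic_cache bookkeeping is replaced by the simpler is_first_iteration flag: on the first pass the model outputs cover the whole prompt, so the per-step length is new_cur_len; on later passes only the n_matches + 1 newly accepted tokens are appended. A small sketch of that rule (hypothetical helper, mirroring the variables in the diff):

def added_length(is_first_iteration, new_cur_len, n_matches):
    # How many positions of scores/attentions/hidden states this step produced.
    newly_added_length = n_matches + 1  # accepted candidate tokens plus the bonus token
    if is_first_iteration:
        # The first forward pass also covers the prompt, so the outputs span the
        # full current length rather than just the newly generated tokens.
        newly_added_length = new_cur_len
    return newly_added_length

assert added_length(True, new_cur_len=12, n_matches=3) == 12
assert added_length(False, new_cur_len=16, n_matches=3) == 4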