Generate: move llama `prepare_inputs_for_generation` to `GenerationMixin` (#33677)
```diff
@@ -31,6 +31,7 @@
     EncoderDecoderCache,
     OffloadedCache,
     QuantizedCacheConfig,
+    StaticCache,
 )
 from ..configuration_utils import PretrainedConfig
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
```
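For context on why `StaticCache` is now imported here: the new default `prepare_inputs_for_generation` (below) special-cases static caches, which users opt into via `generate()`'s `cache_implementation` argument. A minimal usage sketch, not part of this diff:

```python
# A static (fixed-shape) KV cache keeps tensor shapes stable across decoding
# steps, which is what makes torch.compile / CUDA-graph capture worthwhile.
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    cache_implementation="static",  # allocates a StaticCache internally
)
```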
```diff
@@ -342,10 +343,99 @@ class GenerationMixin:
     To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
     """

-    def prepare_inputs_for_generation(self, *args, **kwargs):
-        raise NotImplementedError(
-            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
-        )
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        use_cache: bool = True,
+        num_logits_to_keep: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
+        slicing inputs given the existing cache.
+
+        See the documentation in the used model for the arguments (different models might have different requirements
+        for e.g. `past_key_values`). Should work as is for most LLMs.
+        """
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        # Exception 1: when passing input_embeds, input_ids may be missing entries
+        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        if past_key_values is not None:
+            if inputs_embeds is not None:  # Exception 1
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no-op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
+                # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride during
+                # the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case,
+                # `position_ids` is already contiguous but with varying stride which retriggers a capture.
+                position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+        else:
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
```
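The on-the-fly `position_ids` computation above is worth unpacking: for a left-padded batch, `attention_mask.long().cumsum(-1) - 1` assigns each real token its 0-based position, and padding slots are then filled with a harmless dummy value. A standalone illustration (the tensor values are made up for this example, not taken from the PR):

```python
import torch

# Left-padded batch: row 0 has two padding tokens, row 1 has none.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum turns the mask into running token counts; subtracting 1
# yields 0-based positions for the real tokens (padding lands at -1).
position_ids = attention_mask.long().cumsum(-1) - 1
# Replace the -1s at padding slots with a dummy value of 1.
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```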
```diff
+        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
+            if model_inputs["inputs_embeds"] is not None:
+                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
+                device = model_inputs["inputs_embeds"].device
+            else:
+                batch_size, sequence_length = model_inputs["input_ids"].shape
+                device = model_inputs["input_ids"].device
+
+            # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
+            # the 4D causal mask exists, it should be present in the base model (XXXModel class).
+            base_model = getattr(self, self.base_model_prefix)
+            causal_mask_creation_function = getattr(
+                base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
+            )
+            if causal_mask_creation_function is None:
+                logger.warning_once(
+                    f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
+                    "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
+                    "writing code, see Llama for an example implementation. If you're a user, please report this "
+                    "issue on GitHub."
+                )
```

**Comment on lines +402 to +414** (Contributor, Author): I've added this extra logic to throw the warning in case something goes wrong when moving the function :) It will also be useful for other models in the future, since not all of them have this function.
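Since `_prepare_4d_causal_attention_mask_with_cache_position` lives in each base model rather than in this file, a sketch of what such a helper plausibly computes may help; it is invoked in the continuation of the diff just below. This is a hedged approximation of a Llama-style implementation, not the PR's code:

```python
import torch

def make_4d_causal_mask(attention_mask, sequence_length, target_length,
                        dtype, device, cache_position, batch_size):
    # Fixed output shape (batch, 1, sequence_length, target_length), independent
    # of how far decoding has progressed -- this is what avoids recompilations.
    min_value = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_value,
                      dtype=dtype, device=device)
    # Causal rule: key slot j is visible to the query at cache position i iff j <= i.
    visible = torch.arange(target_length, device=device) <= cache_position.reshape(-1, 1)
    mask = mask.masked_fill(visible, 0.0)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    if attention_mask is not None:
        # Fold the 2D padding mask into the first `attention_mask.shape[-1]` key slots.
        pad = attention_mask[:, None, None, :].to(torch.bool)
        length = attention_mask.shape[-1]
        mask[..., :length] = mask[..., :length].masked_fill(~pad, min_value)
    return mask
```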
```diff
+            else:
+                attention_mask = causal_mask_creation_function(
+                    attention_mask,
+                    sequence_length=sequence_length,
+                    target_length=past_key_values.get_max_length(),
+                    dtype=self.get_output_embeddings().weight.dtype,
+                    device=device,
+                    cache_position=cache_position,
+                    batch_size=batch_size,
+                )
+
+        if num_logits_to_keep is not None:
+            model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
     def _prepare_model_inputs(
         self,
```
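To see where this method sits in the pipeline: at each step, `generate()` calls `prepare_inputs_for_generation` to turn the running `input_ids` and cache state into the kwargs for a single `forward` pass. A hedged sketch of one equivalent manual decoding step (the checkpoint name is a placeholder, and the real `generate()` internals do more than this):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("Hello", return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
cache_position = torch.arange(input_ids.shape[1])

# One decoding step: build the forward kwargs, run the model, pick a token.
model_inputs = model.prepare_inputs_for_generation(
    input_ids, attention_mask=attention_mask, cache_position=cache_position
)
with torch.no_grad():
    outputs = model(**model_inputs)
next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
input_ids = torch.cat([input_ids, next_token], dim=-1)
```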
**Comment** (on the method's docstring): new docstring
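The practical upshot of the PR: model classes whose needs match the default (Llama-style decoder-only LLMs) can drop their own `prepare_inputs_for_generation` override and inherit this one. A hypothetical minimal sketch, with an illustrative class name that is not from the diff:

```python
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin


class MyDecoderForCausalLM(PreTrainedModel, GenerationMixin):
    base_model_prefix = "model"  # the backbone attribute the mixin inspects

    # No `prepare_inputs_for_generation` override needed: the GenerationMixin
    # default handles cache slicing, on-the-fly `position_ids`, and (for
    # StaticCache) the precomputed 4D causal mask, provided the base model
    # exposes `_prepare_4d_causal_attention_mask_with_cache_position`.
    ...
```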