From 13c07d0a7801ec19b26166b6a8a4e1543be171c0 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 17 Aug 2024 11:24:17 +0000 Subject: [PATCH 1/9] tmp --- .../generation/configuration_utils.py | 109 +++--- src/transformers/generation/utils.py | 329 +++++++++--------- 2 files changed, 230 insertions(+), 208 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index aa5e77ac6817..bc464e3d6f2e 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin): [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. penalty_alpha (`float`, *optional*): The values balance the model confidence and the degeneration penalty in contrastive search decoding. + dola_layers (`str` or `List[int]`, *optional*): + The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must + be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively. + "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the + layers up to the last 20 layers. + If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa. + The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks, + `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md) + or [the paper](https://arxiv.org/abs/2309.03883) for more details. + + > Parameters that control the cache + use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + cache_implementation (`str`, *optional*, default to `None`): + Cache class that should be used when generating. + cache_config (`CacheConfig` or `dict`, *optional*, default to `None`): + Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and + it will be converted to its repsective `CacheConfig` internally. + Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. + return_legacy_cache (`bool`, *optional*, default to `True`): + Whether to return the legacy or new format of the cache when `DynamicCache` is used by default. > Parameters for manipulation of the model output logits @@ -307,29 +327,6 @@ class GenerationConfig(PushToHubMixin): max_matching_ngram_size (`int`, *optional*, default to `None`): The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided. - > Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883) - - dola_layers (`str` or `List[int]`, *optional*): - The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must - be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively. - "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the - layers up to the last 20 layers. - If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa. - The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks, - `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md) - or [the paper](https://arxiv.org/abs/2309.03883) for more details. - - > Parameters specific to the caching mechanism: - - cache_implementation (`str`, *optional*, default to `None`): - Cache class that should be used when generating. - cache_config (`CacheConfig` or `dict`, *optional*, default to `None`): - Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and - it will be converted to its repsective `CacheConfig` internally. - Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. - return_legacy_cache (`bool`, *optional*, default to `True`): - Whether to return the legacy or new format of the cache when `DynamicCache` is used by default. - > Wild card generation_kwargs: @@ -352,7 +349,19 @@ def __init__(self, **kwargs): self.num_beams = kwargs.pop("num_beams", 1) self.num_beam_groups = kwargs.pop("num_beam_groups", 1) self.penalty_alpha = kwargs.pop("penalty_alpha", None) + self.dola_layers = kwargs.pop("dola_layers", None) + + # Parameters that control the cache self.use_cache = kwargs.pop("use_cache", True) + self.cache_implementation = kwargs.pop("cache_implementation", None) + self.cache_config = kwargs.pop("cache_config", None) + if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG: + cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation] + if self.cache_config is None: + self.cache_config = cache_config_class() + elif isinstance(self.cache_config, dict): + self.cache_config = cache_config_class.from_dict(self.cache_config) + self.return_legacy_cache = kwargs.pop("return_legacy_cache", None) # Parameters for manipulation of the model output logits self.temperature = kwargs.pop("temperature", 1.0) @@ -411,20 +420,6 @@ def __init__(self, **kwargs): self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5) self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") - # DoLa generation - self.dola_layers = kwargs.pop("dola_layers", None) - - # Cache implementation - self.cache_implementation = kwargs.pop("cache_implementation", None) - self.cache_config = kwargs.pop("cache_config", None) - if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG: - cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation] - if self.cache_config is None: - self.cache_config = cache_config_class() - elif isinstance(self.cache_config, dict): - self.cache_config = cache_config_class.from_dict(self.cache_config) - self.return_legacy_cache = kwargs.pop("return_legacy_cache", True) - # Prompt lookup decoding self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None) @@ -544,8 +539,9 @@ def validate(self, is_init=False): raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.") if self.pad_token_id is not None and self.pad_token_id < 0: warnings.warn( - f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. " - "Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values." + f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch " + "generating, if there is padding. Please set `pad_token_id` explicitly as " + "`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation" ) # Validation of attribute relations: @@ -675,6 +671,14 @@ def validate(self, is_init=False): group_error_prefix + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical." ) + # DoLa generation + if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2): + warnings.warn( + "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of " + f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for " + "DoLa decoding is `repetition_penalty>=1.2`.", + UserWarning, + ) # 4. check `num_return_sequences` if self.num_return_sequences != 1: @@ -690,7 +694,7 @@ def validate(self, is_init=False): f"({self.num_beams})." ) - # 5. check `cache_config` + # 5. check cache-related arguments if self.cache_config is not None: cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation) if cache_class is None: @@ -702,6 +706,20 @@ def validate(self, is_init=False): if not isinstance(self.cache_config, cache_class): self.cache_config = cache_class.from_dict(self.cache_config) self.cache_config.validate() + if self.use_cache is False: + # In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often used + # passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error + # (otherwise a user might need to overwrite several parameters). + no_cache_warning = ( + "You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will " + "have no effect." + ) + for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"): + if getattr(self, arg_name) is not None: + warnings.warn( + no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)), + UserWarning, + ) # 6. check watermarking arguments if self.watermarking_config is not None: @@ -727,17 +745,6 @@ def validate(self, is_init=False): "`generate()` (or a pipeline) directly." ) - # 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2 - if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2): - dola_decoding_wrong_parameter_msg = ( - "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, " - "which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`." - ) - warnings.warn( - dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty), - UserWarning, - ) - def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6ab88ad26bf3..af33086b6dda 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -136,27 +136,23 @@ class GenerateDecoderOnlyOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None @@ -176,36 +172,32 @@ class GenerateEncoderDecoderOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None @@ -228,33 +220,29 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None @@ -276,43 +264,39 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None @@ -328,6 +312,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None +# TODO (joao): remove the equivalent classes and typing shortcuts below in v5 # Equivalent classes (kept for retrocompatibility purposes) GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput @@ -1501,6 +1486,126 @@ def _supports_default_dynamic_cache(self) -> bool: """ return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower() + def _prepare_cache_for_generation( + self, + generation_config: GenerationConfig, + model_kwargs: Dict, + assistant_model: PreTrainedModel, + batch_size: int, + device: torch.device, + ) -> bool: + """ + Prepares the cache for generation (if applicable), given `generate`'s paramaterization. If a cache is + instantiated, writes it to `model_kwargs`, under the name expected by the model. + """ + use_dynamic_cache_by_default = False + + if "mamba" in self.__class__.__name__.lower(): + cache_name = "cache_params" + else: + cache_name = "past_key_values" + + # Quick escape route 1: if the user specifies a cache, we don't need to do anything (other than check for + # conflicting `generate` arguments) + user_defined_cache = model_kwargs.get(cache_name) + if user_defined_cache is not None: + if is_torchdynamo_compiling(): + raise ValueError( + "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you " + "may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` " + "input argument." + ) + if generation_config.cache_implementation is not None: + raise ValueError( + f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + return use_dynamic_cache_by_default + + # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in + # `generation_config.validate()`) + + # Otherwise we may need to prepare a cache, based on `generation_config.cache_implementation` + # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, + # which is only supported in dynamic caches atm + if ( + assistant_model is not None + and generation_config.cache_implementation is not None + and self._supports_default_dynamic_cache() + ): + logger.warning_once( + "An assistant model is provided, using a dynamic cache instead of a cache of type=" + f"'{generation_config.cache_implementation}'." + ) + generation_config.cache_implementation = None + + if generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + model_kwargs[cache_name] = self._get_cache( + cache_implementation=generation_config.cache_implementation, + max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, + max_cache_len=generation_config.max_length, + device=device, + model_kwargs=model_kwargs, + ) + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue." + ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() + ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_quanto_available(): + raise ImportError( + "You need to install `quanto` in order to use KV cache quantization with quanto backend. " + "Please install it via with `pip install quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs[cache_name] = cache_class(cache_config) + elif generation_config.cache_implementation == "offloaded": + model_kwargs[cache_name] = OffloadedCache() + # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): + past = model_kwargs.get(cache_name, None) + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + if past is None: + model_kwargs[cache_name] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) + use_dynamic_cache_by_default = True + elif isinstance(past, tuple): + model_kwargs[cache_name] = ( + DynamicCache.from_legacy_cache(past) + if not requires_cross_attention_cache + else EncoderDecoderCache.from_legacy_cache(past) + ) + use_dynamic_cache_by_default = True + + return use_dynamic_cache_by_default + + def _prepare_special_tokens( self, generation_config: GenerationConfig, @@ -1776,104 +1881,14 @@ def generate( inputs_tensor=inputs_tensor, input_ids_length=input_ids_length, ) - - use_dynamic_cache_by_default = False - if "mamba" in self.__class__.__name__.lower(): - cache_name = "cache_params" - else: - cache_name = "past_key_values" - - # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, - # which is only supported in dynamic caches atm - if ( - assistant_model is not None - and generation_config.cache_implementation is not None - and self._supports_default_dynamic_cache() - ): - logger.warning_once( - "An assistant model is provided, using a dynamic cache instead of a cache of type=" - f"'{generation_config.cache_implementation}'." - ) - generation_config.cache_implementation = None - - if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling(): - raise ValueError( - "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you " - "may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` " - "input argument." - ) - if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None): - raise ValueError( - f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " - "Cache object) is unsupported. Please use only one of the two." - ) - elif generation_config.cache_implementation is not None: - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if generation_config.cache_implementation == "static" and not self._supports_static_cache: - raise ValueError( - "This model does not support `cache_implementation='static'`. Please check the following " - "issue: https://github.com/huggingface/transformers/issues/28981" - ) - model_kwargs[cache_name] = self._get_cache( - cache_implementation=generation_config.cache_implementation, - batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, - max_cache_len=generation_config.max_length, - device=device, - model_kwargs=model_kwargs, - ) - elif generation_config.cache_implementation == "quantized": - if not self._supports_quantized_cache: - raise ValueError( - "This model does not support the quantized cache. If you want your model to support quantized " - "cache, please open an issue." - ) - - cache_config = ( - generation_config.cache_config - if generation_config.cache_config is not None - else QuantizedCacheConfig() - ) - cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] - - if cache_config.backend == "quanto" and not is_quanto_available(): - raise ImportError( - "You need to install `quanto` in order to use KV cache quantization with quanto backend. " - "Please install it via with `pip install quanto`" - ) - elif cache_config.backend == "HQQ" and not is_hqq_available(): - raise ImportError( - "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " - "Please install it via with `pip install hqq`" - ) - - model_kwargs[cache_name] = cache_class(cache_config) - elif generation_config.cache_implementation == "offloaded": - model_kwargs[cache_name] = OffloadedCache() - # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that - # keeps copying the cache thus using much more memory - elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): - past = model_kwargs.get(cache_name, None) - requires_cross_attention_cache = ( - self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None - ) - if past is None: - model_kwargs[cache_name] = ( - DynamicCache() - if not requires_cross_attention_cache - else EncoderDecoderCache(DynamicCache(), DynamicCache()) - ) - use_dynamic_cache_by_default = True - elif isinstance(past, tuple): - model_kwargs[cache_name] = ( - DynamicCache.from_legacy_cache(past) - if not requires_cross_attention_cache - else EncoderDecoderCache.from_legacy_cache(past) - ) - use_dynamic_cache_by_default = True - self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) - # 7. determine generation mode + # 7. Prepare the cache. + # - `model_kwargs` may be updated in place with the appropriate cache + # - `max_length`, prepared above, is used to determine the cache length + self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device) + + # 8. determine generation mode generation_mode = generation_config.get_generation_mode(assistant_model) if streamer is not None and (generation_config.num_beams > 1): @@ -1892,7 +1907,7 @@ def generate( UserWarning, ) - # 8. prepare distribution pre_processing samplers + # 9. prepare logits processors and stopping criteria prepared_logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, @@ -1904,8 +1919,6 @@ def generate( negative_prompt_ids=negative_prompt_ids, negative_prompt_attention_mask=negative_prompt_attention_mask, ) - - # 9. prepare stopping criteria prepared_stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs ) @@ -2138,11 +2151,13 @@ def typeerror(): **model_kwargs, ) - # Convert to legacy cache if needed - if use_dynamic_cache_by_default and generation_config.return_legacy_cache: - if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"): - if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)): - result.past_key_values = result.past_key_values.to_legacy_cache() + # Convert to legacy cache format if requested + if ( + generation_config.return_legacy_cache + and hasattr(result, "past_key_values") + and isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)) + ): + result.past_key_values = result.past_key_values.to_legacy_cache() return result def _has_unfinished_sequences( From a6611e70b608c775a32a275b81cb04187ca89d02 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 17 Aug 2024 12:21:05 +0000 Subject: [PATCH 2/9] organize cache init --- src/transformers/generation/utils.py | 84 ++++++++++++++++------------ 1 file changed, 48 insertions(+), 36 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index af33086b6dda..7732aa93e23d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1490,7 +1490,7 @@ def _prepare_cache_for_generation( self, generation_config: GenerationConfig, model_kwargs: Dict, - assistant_model: PreTrainedModel, + assistant_model: "PreTrainedModel", batch_size: int, device: torch.device, ) -> bool: @@ -1498,15 +1498,19 @@ def _prepare_cache_for_generation( Prepares the cache for generation (if applicable), given `generate`'s paramaterization. If a cache is instantiated, writes it to `model_kwargs`, under the name expected by the model. """ - use_dynamic_cache_by_default = False if "mamba" in self.__class__.__name__.lower(): cache_name = "cache_params" else: cache_name = "past_key_values" - # Quick escape route 1: if the user specifies a cache, we don't need to do anything (other than check for - # conflicting `generate` arguments) + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + + # Quick escape route 1: if the user specifies a cache, we only need to: + # a) check for conflicting `generate` arguments + # b) convert to the new cache format (if the user passes a legacy cache and model supports it) user_defined_cache = model_kwargs.get(cache_name) if user_defined_cache is not None: if is_torchdynamo_compiling(): @@ -1520,19 +1524,35 @@ def _prepare_cache_for_generation( f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " "Cache object) is unsupported. Please use only one of the two." ) - return use_dynamic_cache_by_default + if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache(): + model_kwargs[cache_name] = ( + DynamicCache.from_legacy_cache(user_defined_cache) + if not requires_cross_attention_cache + else EncoderDecoderCache.from_legacy_cache(user_defined_cache) + ) + return # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in # `generation_config.validate()`) + if generation_config.use_cache is False: + return + + # Quick escape route 3: model that only supports legacy caches = nothing to prepare + if not self._supports_default_dynamic_cache(): + if generation_config.cache_implementation is not None: + warnings.warn( + "This model does not support `Cache` instances, it only supports the legacy cache format (tuple " + f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be " + "ignored.", + UserWarning, + ) + return + + # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation` - # Otherwise we may need to prepare a cache, based on `generation_config.cache_implementation` # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, # which is only supported in dynamic caches atm - if ( - assistant_model is not None - and generation_config.cache_implementation is not None - and self._supports_default_dynamic_cache() - ): + if assistant_model is not None and generation_config.cache_implementation is not None: logger.warning_once( "An assistant model is provided, using a dynamic cache instead of a cache of type=" f"'{generation_config.cache_implementation}'." @@ -1557,7 +1577,7 @@ def _prepare_cache_for_generation( if not self._supports_quantized_cache: raise ValueError( "This model does not support the quantized cache. If you want your model to support quantized " - "cache, please open an issue." + "cache, please open an issue and tag @zucchini-nlp." ) cache_config = ( @@ -1581,30 +1601,15 @@ def _prepare_cache_for_generation( model_kwargs[cache_name] = cache_class(cache_config) elif generation_config.cache_implementation == "offloaded": model_kwargs[cache_name] = OffloadedCache() + # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that # keeps copying the cache thus using much more memory - elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): - past = model_kwargs.get(cache_name, None) - requires_cross_attention_cache = ( - self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + else: + model_kwargs[cache_name] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) ) - if past is None: - model_kwargs[cache_name] = ( - DynamicCache() - if not requires_cross_attention_cache - else EncoderDecoderCache(DynamicCache(), DynamicCache()) - ) - use_dynamic_cache_by_default = True - elif isinstance(past, tuple): - model_kwargs[cache_name] = ( - DynamicCache.from_legacy_cache(past) - if not requires_cross_attention_cache - else EncoderDecoderCache.from_legacy_cache(past) - ) - use_dynamic_cache_by_default = True - - return use_dynamic_cache_by_default - def _prepare_special_tokens( self, @@ -1884,8 +1889,9 @@ def generate( self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) # 7. Prepare the cache. - # - `model_kwargs` may be updated in place with the appropriate cache - # - `max_length`, prepared above, is used to determine the cache length + # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. + # - different models have a different cache name expected by the model (default = "past_key_values") + # - `max_length`, prepared above, is used to determine the maximum cache length self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device) # 8. determine generation mode @@ -2153,10 +2159,16 @@ def typeerror(): # Convert to legacy cache format if requested if ( - generation_config.return_legacy_cache + generation_config.return_legacy_cache is not False and hasattr(result, "past_key_values") and isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)) ): + if generation_config.return_legacy_cache is None: + logger.warning_once( + "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` " + "instance instead by default (as opposed to the legacy tuple of tuples format). If you want to " + "keep returning the legacy format, please set `return_legacy_cache=True`." + ) result.past_key_values = result.past_key_values.to_legacy_cache() return result From 69bf5f4b66897273b32d5752b5bee5644deb02f3 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 17 Aug 2024 12:30:11 +0000 Subject: [PATCH 3/9] fix conflict --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 7732aa93e23d..8d4ea9288bc6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1568,7 +1568,7 @@ def _prepare_cache_for_generation( ) model_kwargs[cache_name] = self._get_cache( cache_implementation=generation_config.cache_implementation, - max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, + batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, max_cache_len=generation_config.max_length, device=device, model_kwargs=model_kwargs, From d3c3e5ae89fad4e3521c61f727224d1886e2b626 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 17 Aug 2024 12:42:02 +0000 Subject: [PATCH 4/9] update tests --- tests/generation/test_utils.py | 64 ++++++++++++++++------------------ 1 file changed, 30 insertions(+), 34 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 72da44115f5c..3392778a9093 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -419,7 +419,6 @@ def test_greedy_generate_dict_outputs(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = False model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, @@ -430,6 +429,7 @@ def test_greedy_generate_dict_outputs(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -454,7 +454,6 @@ def test_greedy_generate_dict_outputs_use_cache(self): if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( @@ -466,6 +465,7 @@ def test_greedy_generate_dict_outputs_use_cache(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: @@ -495,7 +495,6 @@ def test_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = False model = model_class(config).to(torch_device).eval() output_generate = self._sample_generate( model=model, @@ -507,6 +506,7 @@ def test_sample_generate_dict_output(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -545,9 +545,6 @@ def test_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - # disable cache - config.use_cache = False - model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_search_generate( @@ -560,6 +557,7 @@ def test_beam_search_generate_dict_output(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) @@ -589,7 +587,6 @@ def test_beam_search_generate_dict_outputs_use_cache(self): model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() output_generate = self._beam_search_generate( @@ -602,6 +599,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: @@ -676,9 +674,6 @@ def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - # disable cache - config.use_cache = False - model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() @@ -692,6 +687,7 @@ def test_beam_sample_generate_dict_output(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -764,7 +760,6 @@ def test_group_beam_search_generate(self): def test_group_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = False model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_diverse_beam_kwargs() @@ -778,6 +773,7 @@ def test_group_beam_search_generate_dict_output(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) @@ -857,9 +853,6 @@ def test_constrained_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - # disable cache - config.use_cache = False - model = model_class(config).to(torch_device).eval() # Sample constraints @@ -882,6 +875,7 @@ def test_constrained_beam_search_generate_dict_output(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -913,13 +907,12 @@ def test_contrastive_generate(self): # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True # test old generation output for backwards compatibility model = model_class(config).to(torch_device).eval() output_generate = self._contrastive_generate( - model=model, input_ids=input_ids, attention_mask=attention_mask + model=model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) @@ -940,7 +933,6 @@ def test_contrastive_generate_dict_outputs_use_cache(self): # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() @@ -953,6 +945,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: @@ -978,7 +971,6 @@ def test_contrastive_generate_low_memory(self): if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True # test output equality of low versus high memory @@ -991,6 +983,7 @@ def test_contrastive_generate_low_memory(self): low_memory=True, max_new_tokens=self.max_new_tokens, attention_mask=attention_mask, + use_cache=True, ) high_output = model.generate( @@ -1000,6 +993,7 @@ def test_contrastive_generate_low_memory(self): low_memory=False, max_new_tokens=self.max_new_tokens, attention_mask=attention_mask, + use_cache=True, ) self.assertListEqual(low_output.tolist(), high_output.tolist()) @@ -1031,10 +1025,17 @@ def test_beam_search_low_memory(self): # test output equality of low versus high memory model = model_class(config).to(torch_device).eval() - low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True) + low_output = model.generate( + input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True, use_cache=True + ) high_output = model.generate( - input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False + input_ids, + max_new_tokens=8, + num_beams=5, + early_stopping=True, + low_memory=False, + use_cache=True, ) self.assertListEqual(low_output.tolist(), high_output.tolist()) @@ -1079,7 +1080,6 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1098,6 +1098,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": True, } output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1150,7 +1151,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1169,6 +1169,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": True, } output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1196,12 +1197,6 @@ def test_dola_decoding_sample(self): # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm config, input_ids, attention_mask = self._get_input_ids_and_config() - # Some models don't support the cache and returning past_key_values - if not hasattr(config, "use_cache"): - config.use_cache = False - else: - config.use_cache = True - # Encoder-decoder models are not supported if config.is_encoder_decoder: self.skipTest("DoLa is not supported for encoder-decoder models") @@ -1224,11 +1219,12 @@ def test_dola_decoding_sample(self): "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": hasattr(config, "use_cache"), # Some models don't support the cache } generation_kwargs.update({"dola_layers": "low"}) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs) - self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache) + self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache")) def test_assisted_decoding_sample(self): # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not @@ -1261,7 +1257,6 @@ def test_assisted_decoding_sample(self): if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1284,6 +1279,7 @@ def test_assisted_decoding_sample(self): "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": True, } output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1566,7 +1562,6 @@ def test_generate_continue_from_past_key_values(self): # 3. ignore `token_type_ids` for simplicity # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is # active by default on some models - config.use_cache = True if "token_type_ids" in inputs: del inputs["token_type_ids"] @@ -1574,6 +1569,7 @@ def test_generate_continue_from_past_key_values(self): model.eval() model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 model.generation_config.forced_eos_token_id = None + model.generation_config.use_cache = True # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) outputs = model(**inputs) @@ -1631,7 +1627,6 @@ def test_new_cache_format(self, num_beams, do_sample): self.skipTest(reason="This model does not support the new cache format") config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = True model = model_class(config).to(torch_device).eval() generation_kwargs = { @@ -1640,6 +1635,7 @@ def test_new_cache_format(self, num_beams, do_sample): "num_beams": num_beams, "num_return_sequences": num_beams, "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, } # Sets seed before calling `generate` for the case with do_sample=True @@ -1701,7 +1697,6 @@ def test_generate_with_static_cache(self): if config.is_encoder_decoder: self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") - config.use_cache = True config.is_decoder = True batch_size, seq_length = input_ids.shape max_new_tokens = 20 @@ -1712,6 +1707,7 @@ def test_generate_with_static_cache(self): "max_new_tokens": max_new_tokens, "cache_implementation": "static", "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, } max_cache_len = seq_length + max_new_tokens @@ -1740,7 +1736,6 @@ def test_generate_with_quant_cache(self): self.skipTest(reason="This model does not support the quantized cache format") config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() @@ -1750,6 +1745,7 @@ def test_generate_with_quant_cache(self): # careful with group size, should be divisor of model's hidden size "cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128}, "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, } results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) From bc5b50a72d786b27e01fef98de256af9cbd3d9cc Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 17 Aug 2024 14:15:36 +0000 Subject: [PATCH 5/9] handle corner cases --- src/transformers/generation/utils.py | 33 ++++++++++++++++++---------- tests/generation/test_utils.py | 16 ++++++++++++++ 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8d4ea9288bc6..380deae46c72 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1499,11 +1499,7 @@ def _prepare_cache_for_generation( instantiated, writes it to `model_kwargs`, under the name expected by the model. """ - if "mamba" in self.__class__.__name__.lower(): - cache_name = "cache_params" - else: - cache_name = "past_key_values" - + cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" requires_cross_attention_cache = ( self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None ) @@ -1515,9 +1511,9 @@ def _prepare_cache_for_generation( if user_defined_cache is not None: if is_torchdynamo_compiling(): raise ValueError( - "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you " - "may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` " - "input argument." + "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- " + "you may get incorrect outputs. Please compile `model.forward` only or use the " + "`cache_implementation` input argument." ) if generation_config.cache_implementation is not None: raise ValueError( @@ -1892,6 +1888,9 @@ def generate( # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. # - different models have a different cache name expected by the model (default = "past_key_values") # - `max_length`, prepared above, is used to determine the maximum cache length + # TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format) + cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" + user_defined_cache = model_kwargs.get(cache_name) self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device) # 8. determine generation mode @@ -2159,17 +2158,27 @@ def typeerror(): # Convert to legacy cache format if requested if ( - generation_config.return_legacy_cache is not False + generation_config.return_legacy_cache is not False # Should check for `True` after v4.47 and hasattr(result, "past_key_values") - and isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)) + and hasattr(result.past_key_values, "to_legacy_cache") + and result.past_key_values.to_legacy_cache is not None ): - if generation_config.return_legacy_cache is None: + # handle BC (convert by default if he user hasn't passed a cache AND the cache is of the default type) + should_convert_cache = generation_config.return_legacy_cache + is_user_defined_cache = user_defined_cache is not None + is_default_cache_type = ( + type(result.past_key_values) == DynamicCache # noqa E721 + or type(result.past_key_values) == EncoderDecoderCache # noqa E721 + ) + if not is_user_defined_cache and is_default_cache_type: logger.warning_once( "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` " "instance instead by default (as opposed to the legacy tuple of tuples format). If you want to " "keep returning the legacy format, please set `return_legacy_cache=True`." ) - result.past_key_values = result.past_key_values.to_legacy_cache() + should_convert_cache = True + if should_convert_cache: + result.past_key_values = result.past_key_values.to_legacy_cache() return result def _has_unfinished_sequences( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 3392778a9093..f91172d5c588 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -194,6 +194,7 @@ def _greedy_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -207,6 +208,7 @@ def _greedy_generate( output_scores=output_scores, output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **logits_processor_kwargs, **model_kwargs, ) @@ -224,6 +226,7 @@ def _sample_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): torch.manual_seed(0) logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True) @@ -239,6 +242,7 @@ def _sample_generate( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **logits_processor_kwargs, **model_kwargs, ) @@ -256,6 +260,7 @@ def _beam_search_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -268,6 +273,7 @@ def _beam_search_generate( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -286,6 +292,7 @@ def _beam_sample_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): torch.manual_seed(0) logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True) @@ -299,6 +306,7 @@ def _beam_sample_generate( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -317,6 +325,7 @@ def _group_beam_search_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -329,6 +338,7 @@ def _group_beam_search_generate( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -348,6 +358,7 @@ def _constrained_beam_search_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -361,6 +372,7 @@ def _constrained_beam_search_generate( output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, constraints=constraints, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -378,6 +390,7 @@ def _contrastive_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): contrastive_search_kwargs = { "penalty_alpha": 0.6, @@ -396,6 +409,7 @@ def _contrastive_generate( output_scores=output_scores, output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **logits_processor_kwargs, **model_kwargs, **contrastive_search_kwargs, @@ -1902,6 +1916,8 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_ seq_length=past_sequence_length, config=config, ) + elif use_cache is False: + self.assertTrue(output.past_key_values is None) def _check_scores(self, batch_size, scores, length, config): expected_shape = (batch_size, config.vocab_size) From 8958d49a0bbafa47408dbbfb27da4c550789730d Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 17 Aug 2024 14:50:07 +0000 Subject: [PATCH 6/9] special models --- src/transformers/models/rwkv/modeling_rwkv.py | 3 +- tests/generation/test_utils.py | 30 +++++++++---------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index f6b8cd412be5..7dec1f26e1a3 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -768,7 +768,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.head = new_embeddings - def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs): # only last token for inputs_ids if the state is passed along. if state is not None: input_ids = input_ids[:, -1].unsqueeze(-1) @@ -780,6 +780,7 @@ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=Non model_inputs = {"input_ids": input_ids} model_inputs["state"] = state + model_inputs["use_cache"] = use_cache return model_inputs @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index f91172d5c588..ae52f6c67404 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1900,24 +1900,24 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_ # Past Key Value States -- a few notes here: # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1" - # 2. Some old models still return `output.past_key_values` even without `use_cache=True` - # 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is - # complete - models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba") + # 2. We ignore models that have unique cache structures (e.g. mamba) or are in need of refatoring to match the + # standard cache format (e.g.gptbigcode ) + models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet") has_standard_cache = not any( model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache ) - if use_cache and has_standard_cache: - past_key_values = output.past_key_values - past_sequence_length = output.sequences.shape[-1] - 1 - self._check_past_key_values_for_generate( - num_sequences_in_output, - past_key_values, - seq_length=past_sequence_length, - config=config, - ) - elif use_cache is False: - self.assertTrue(output.past_key_values is None) + if has_standard_cache: + if use_cache: + past_key_values = output.past_key_values + past_sequence_length = output.sequences.shape[-1] - 1 + self._check_past_key_values_for_generate( + num_sequences_in_output, + past_key_values, + seq_length=past_sequence_length, + config=config, + ) + elif use_cache is False: + self.assertTrue(output.past_key_values is None) def _check_scores(self, batch_size, scores, length, config): expected_shape = (batch_size, config.vocab_size) From 914a7eabfbf4261df1303eec03e0e54078585e6b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 17 Aug 2024 15:01:43 +0000 Subject: [PATCH 7/9] whisper is special --- src/transformers/generation/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 380deae46c72..188fa3832909 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2168,7 +2168,11 @@ def typeerror(): is_user_defined_cache = user_defined_cache is not None is_default_cache_type = ( type(result.past_key_values) == DynamicCache # noqa E721 - or type(result.past_key_values) == EncoderDecoderCache # noqa E721 + or ( + isinstance(result.past_key_values, EncoderDecoderCache) + and type(result.past_key_values.self_attention_cache) == DynamicCache + and type(result.past_key_values.cross_attention_cache) == DynamicCache + ) # noqa E721 ) if not is_user_defined_cache and is_default_cache_type: logger.warning_once( From e8492bf62baa7ec991c469151e6145cbd332fab7 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Sat, 17 Aug 2024 15:05:27 +0000 Subject: [PATCH 8/9] make fixup :D --- src/transformers/generation/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 188fa3832909..f0fbc8f77f35 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2170,9 +2170,9 @@ def typeerror(): type(result.past_key_values) == DynamicCache # noqa E721 or ( isinstance(result.past_key_values, EncoderDecoderCache) - and type(result.past_key_values.self_attention_cache) == DynamicCache - and type(result.past_key_values.cross_attention_cache) == DynamicCache - ) # noqa E721 + and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 + and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721 + ) ) if not is_user_defined_cache and is_default_cache_type: logger.warning_once( From 550f7a6d872651d0d7f59ec1338c11209bc9b251 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Aug 2024 17:30:03 +0000 Subject: [PATCH 9/9] PR comments --- src/transformers/generation/configuration_utils.py | 2 +- src/transformers/generation/utils.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index bc464e3d6f2e..160a8a7eae2d 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -716,7 +716,7 @@ def validate(self, is_init=False): ) for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"): if getattr(self, arg_name) is not None: - warnings.warn( + logger.warning_once( no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)), UserWarning, ) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f0fbc8f77f35..a9ebdcdd4775 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1509,12 +1509,6 @@ def _prepare_cache_for_generation( # b) convert to the new cache format (if the user passes a legacy cache and model supports it) user_defined_cache = model_kwargs.get(cache_name) if user_defined_cache is not None: - if is_torchdynamo_compiling(): - raise ValueError( - "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- " - "you may get incorrect outputs. Please compile `model.forward` only or use the " - "`cache_implementation` input argument." - ) if generation_config.cache_implementation is not None: raise ValueError( f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " @@ -2159,6 +2153,7 @@ def typeerror(): # Convert to legacy cache format if requested if ( generation_config.return_legacy_cache is not False # Should check for `True` after v4.47 + and not is_torchdynamo_compiling() and hasattr(result, "past_key_values") and hasattr(result.past_key_values, "to_legacy_cache") and result.past_key_values.to_legacy_cache is not None