diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index aa5e77ac6817..160a8a7eae2d 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin): [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. penalty_alpha (`float`, *optional*): The values balance the model confidence and the degeneration penalty in contrastive search decoding. + dola_layers (`str` or `List[int]`, *optional*): + The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must + be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively. + "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the + layers up to the last 20 layers. + If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa. + The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks, + `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md) + or [the paper](https://arxiv.org/abs/2309.03883) for more details. + + > Parameters that control the cache + use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. + cache_implementation (`str`, *optional*, default to `None`): + Cache class that should be used when generating. + cache_config (`CacheConfig` or `dict`, *optional*, default to `None`): + Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and + it will be converted to its respective `CacheConfig` internally. 
+ Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. + return_legacy_cache (`bool`, *optional*, default to `True`): + Whether to return the legacy or new format of the cache when `DynamicCache` is used by default. > Parameters for manipulation of the model output logits @@ -307,29 +327,6 @@ class GenerationConfig(PushToHubMixin): max_matching_ngram_size (`int`, *optional*, default to `None`): The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided. - > Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883) - - dola_layers (`str` or `List[int]`, *optional*): - The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must - be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively. - "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the - layers up to the last 20 layers. - If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa. - The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks, - `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md) - or [the paper](https://arxiv.org/abs/2309.03883) for more details. - - > Parameters specific to the caching mechanism: - - cache_implementation (`str`, *optional*, default to `None`): - Cache class that should be used when generating. - cache_config (`CacheConfig` or `dict`, *optional*, default to `None`): - Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and - it will be converted to its repsective `CacheConfig` internally. 
- Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. - return_legacy_cache (`bool`, *optional*, default to `True`): - Whether to return the legacy or new format of the cache when `DynamicCache` is used by default. - > Wild card generation_kwargs: @@ -352,7 +349,19 @@ def __init__(self, **kwargs): self.num_beams = kwargs.pop("num_beams", 1) self.num_beam_groups = kwargs.pop("num_beam_groups", 1) self.penalty_alpha = kwargs.pop("penalty_alpha", None) + self.dola_layers = kwargs.pop("dola_layers", None) + + # Parameters that control the cache self.use_cache = kwargs.pop("use_cache", True) + self.cache_implementation = kwargs.pop("cache_implementation", None) + self.cache_config = kwargs.pop("cache_config", None) + if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG: + cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation] + if self.cache_config is None: + self.cache_config = cache_config_class() + elif isinstance(self.cache_config, dict): + self.cache_config = cache_config_class.from_dict(self.cache_config) + self.return_legacy_cache = kwargs.pop("return_legacy_cache", None) # Parameters for manipulation of the model output logits self.temperature = kwargs.pop("temperature", 1.0) @@ -411,20 +420,6 @@ def __init__(self, **kwargs): self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5) self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") - # DoLa generation - self.dola_layers = kwargs.pop("dola_layers", None) - - # Cache implementation - self.cache_implementation = kwargs.pop("cache_implementation", None) - self.cache_config = kwargs.pop("cache_config", None) - if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG: - cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation] - if self.cache_config is None: - self.cache_config = cache_config_class() - elif 
isinstance(self.cache_config, dict): - self.cache_config = cache_config_class.from_dict(self.cache_config) - self.return_legacy_cache = kwargs.pop("return_legacy_cache", True) - # Prompt lookup decoding self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None) @@ -544,8 +539,9 @@ def validate(self, is_init=False): raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.") if self.pad_token_id is not None and self.pad_token_id < 0: warnings.warn( - f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. " - "Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values." + f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch " + "generating, if there is padding. Please set `pad_token_id` explicitly as " + "`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation" ) # Validation of attribute relations: @@ -675,6 +671,14 @@ def validate(self, is_init=False): group_error_prefix + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical." ) + # DoLa generation + if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2): + warnings.warn( + "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of " + f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for " + "DoLa decoding is `repetition_penalty>=1.2`.", + UserWarning, + ) # 4. check `num_return_sequences` if self.num_return_sequences != 1: @@ -690,7 +694,7 @@ def validate(self, is_init=False): f"({self.num_beams})." ) - # 5. check `cache_config` + # 5. 
check cache-related arguments if self.cache_config is not None: cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation) if cache_class is None: @@ -702,6 +706,20 @@ def validate(self, is_init=False): if not isinstance(self.cache_config, cache_class): self.cache_config = cache_class.from_dict(self.cache_config) self.cache_config.validate() + if self.use_cache is False: + # In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often + # passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error + # (otherwise a user might need to overwrite several parameters). + no_cache_warning = ( + "You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will " + "have no effect." + ) + for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"): + cache_arg_value = getattr(self, arg_name) + if cache_arg_value is not None: + logger.warning_once( + no_cache_warning.format(cache_arg=arg_name, cache_arg_value=cache_arg_value) + ) # 6. check watermarking arguments if self.watermarking_config is not None: @@ -727,17 +745,6 @@ def validate(self, is_init=False): "`generate()` (or a pipeline) directly." ) - # 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2 - if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2): - dola_decoding_wrong_parameter_msg = ( - "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, " - "which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`." 
- ) - warnings.warn( - dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty), - UserWarning, - ) - def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6ab88ad26bf3..a9ebdcdd4775 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -136,27 +136,23 @@ class GenerateDecoderOnlyOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. 
- attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. 
""" sequences: torch.LongTensor = None @@ -176,36 +172,32 @@ class GenerateEncoderDecoderOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. 
- encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. 
past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None @@ -228,33 +220,29 @@ class GenerateBeamDecoderOnlyOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. - sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Beam transition scores for each vocabulary token at each generation step. 
Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`. - attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`. 
- hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. """ sequences: torch.LongTensor = None @@ -276,43 +264,39 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`): The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early due to the `eos_token_id`. 
- sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`): Final beam scores of the generated `sequences`. - scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): Beam transition scores for each vocabulary token at each generation step. Beam transition scores consisting of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size*num_beams, config.vocab_size)`. - logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`): + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`): Beam indices of generated token id at each generation step. `torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`. 
- encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size*num_beams*num_return_sequences, sequence_length, hidden_size)`. - decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + decoder_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, num_heads, generated_length, sequence_length)`. - cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + cross_attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. 
- decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + decoder_hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`. - past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. - Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value - tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. 
""" sequences: torch.LongTensor = None @@ -328,6 +312,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput): past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None +# TODO (joao): remove the equivalent classes and typing shortcuts below in v5 # Equivalent classes (kept for retrocompatibility purposes) GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput @@ -1501,6 +1486,121 @@ def _supports_default_dynamic_cache(self) -> bool: """ return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower() + def _prepare_cache_for_generation( + self, + generation_config: GenerationConfig, + model_kwargs: Dict, + assistant_model: "PreTrainedModel", + batch_size: int, + device: torch.device, + ) -> bool: + """ + Prepares the cache for generation (if applicable), given `generate`'s paramaterization. If a cache is + instantiated, writes it to `model_kwargs`, under the name expected by the model. + """ + + cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + + # Quick escape route 1: if the user specifies a cache, we only need to: + # a) check for conflicting `generate` arguments + # b) convert to the new cache format (if the user passes a legacy cache and model supports it) + user_defined_cache = model_kwargs.get(cache_name) + if user_defined_cache is not None: + if generation_config.cache_implementation is not None: + raise ValueError( + f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " + "Cache object) is unsupported. Please use only one of the two." 
+ ) + if isinstance(user_defined_cache, tuple) and self._supports_default_dynamic_cache(): + model_kwargs[cache_name] = ( + DynamicCache.from_legacy_cache(user_defined_cache) + if not requires_cross_attention_cache + else EncoderDecoderCache.from_legacy_cache(user_defined_cache) + ) + return + + # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in + # `generation_config.validate()`) + if generation_config.use_cache is False: + return + + # Quick escape route 3: model that only supports legacy caches = nothing to prepare + if not self._supports_default_dynamic_cache(): + if generation_config.cache_implementation is not None: + warnings.warn( + "This model does not support `Cache` instances, it only supports the legacy cache format (tuple " + f"of tuples). `cache_implementation` (set to {generation_config.cache_implementation}) will be " + "ignored.", + UserWarning, + ) + return + + # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation` + + # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, + # which is only supported in dynamic caches atm + if assistant_model is not None and generation_config.cache_implementation is not None: + logger.warning_once( + "An assistant model is provided, using a dynamic cache instead of a cache of type=" + f"'{generation_config.cache_implementation}'." + ) + generation_config.cache_implementation = None + + if generation_config.cache_implementation is not None: + if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation == "static" and not self._supports_static_cache: + raise ValueError( + "This model does not support `cache_implementation='static'`. 
Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981" + ) + model_kwargs[cache_name] = self._get_cache( + cache_implementation=generation_config.cache_implementation, + batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, + max_cache_len=generation_config.max_length, + device=device, + model_kwargs=model_kwargs, + ) + elif generation_config.cache_implementation == "quantized": + if not self._supports_quantized_cache: + raise ValueError( + "This model does not support the quantized cache. If you want your model to support quantized " + "cache, please open an issue and tag @zucchini-nlp." + ) + + cache_config = ( + generation_config.cache_config + if generation_config.cache_config is not None + else QuantizedCacheConfig() + ) + cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] + + if cache_config.backend == "quanto" and not is_quanto_available(): + raise ImportError( + "You need to install `quanto` in order to use KV cache quantization with quanto backend. " + "Please install it via with `pip install quanto`" + ) + elif cache_config.backend == "HQQ" and not is_hqq_available(): + raise ImportError( + "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " + "Please install it via with `pip install hqq`" + ) + + model_kwargs[cache_name] = cache_class(cache_config) + elif generation_config.cache_implementation == "offloaded": + model_kwargs[cache_name] = OffloadedCache() + + # Use DynamicCache() instance by default. 
This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + else: + model_kwargs[cache_name] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) + def _prepare_special_tokens( self, generation_config: GenerationConfig, @@ -1776,104 +1876,18 @@ def generate( inputs_tensor=inputs_tensor, input_ids_length=input_ids_length, ) - - use_dynamic_cache_by_default = False - if "mamba" in self.__class__.__name__.lower(): - cache_name = "cache_params" - else: - cache_name = "past_key_values" - - # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches, - # which is only supported in dynamic caches atm - if ( - assistant_model is not None - and generation_config.cache_implementation is not None - and self._supports_default_dynamic_cache() - ): - logger.warning_once( - "An assistant model is provided, using a dynamic cache instead of a cache of type=" - f"'{generation_config.cache_implementation}'." - ) - generation_config.cache_implementation = None - - if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling(): - raise ValueError( - "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you " - "may get incorrect outputs. Please compile `model.forward` only or use the `cache_implementation` " - "input argument." - ) - if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None): - raise ValueError( - f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a " - "Cache object) is unsupported. Please use only one of the two." 
- ) - elif generation_config.cache_implementation is not None: - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if generation_config.cache_implementation == "static" and not self._supports_static_cache: - raise ValueError( - "This model does not support `cache_implementation='static'`. Please check the following " - "issue: https://github.com/huggingface/transformers/issues/28981" - ) - model_kwargs[cache_name] = self._get_cache( - cache_implementation=generation_config.cache_implementation, - batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size, - max_cache_len=generation_config.max_length, - device=device, - model_kwargs=model_kwargs, - ) - elif generation_config.cache_implementation == "quantized": - if not self._supports_quantized_cache: - raise ValueError( - "This model does not support the quantized cache. If you want your model to support quantized " - "cache, please open an issue." - ) - - cache_config = ( - generation_config.cache_config - if generation_config.cache_config is not None - else QuantizedCacheConfig() - ) - cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend] - - if cache_config.backend == "quanto" and not is_quanto_available(): - raise ImportError( - "You need to install `quanto` in order to use KV cache quantization with quanto backend. " - "Please install it via with `pip install quanto`" - ) - elif cache_config.backend == "HQQ" and not is_hqq_available(): - raise ImportError( - "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. " - "Please install it via with `pip install hqq`" - ) - - model_kwargs[cache_name] = cache_class(cache_config) - elif generation_config.cache_implementation == "offloaded": - model_kwargs[cache_name] = OffloadedCache() - # Use DynamicCache() instance by default. 
This will avoid back and forth from legacy format that - # keeps copying the cache thus using much more memory - elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): - past = model_kwargs.get(cache_name, None) - requires_cross_attention_cache = ( - self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None - ) - if past is None: - model_kwargs[cache_name] = ( - DynamicCache() - if not requires_cross_attention_cache - else EncoderDecoderCache(DynamicCache(), DynamicCache()) - ) - use_dynamic_cache_by_default = True - elif isinstance(past, tuple): - model_kwargs[cache_name] = ( - DynamicCache.from_legacy_cache(past) - if not requires_cross_attention_cache - else EncoderDecoderCache.from_legacy_cache(past) - ) - use_dynamic_cache_by_default = True - self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) - # 7. determine generation mode + # 7. Prepare the cache. + # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`. + # - different models have a different cache name expected by the model (default = "past_key_values") + # - `max_length`, prepared above, is used to determine the maximum cache length + # TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format) + cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params" + user_defined_cache = model_kwargs.get(cache_name) + self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device) + + # 8. determine generation mode generation_mode = generation_config.get_generation_mode(assistant_model) if streamer is not None and (generation_config.num_beams > 1): @@ -1892,7 +1906,7 @@ def generate( UserWarning, ) - # 8. prepare distribution pre_processing samplers + # 9. 
prepare logits processors and stopping criteria prepared_logits_processor = self._get_logits_processor( generation_config=generation_config, input_ids_seq_length=input_ids_length, @@ -1904,8 +1918,6 @@ def generate( negative_prompt_ids=negative_prompt_ids, negative_prompt_attention_mask=negative_prompt_attention_mask, ) - - # 9. prepare stopping criteria prepared_stopping_criteria = self._get_stopping_criteria( generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs ) @@ -2138,11 +2150,34 @@ def typeerror(): **model_kwargs, ) - # Convert to legacy cache if needed - if use_dynamic_cache_by_default and generation_config.return_legacy_cache: - if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"): - if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)): - result.past_key_values = result.past_key_values.to_legacy_cache() + # Convert to legacy cache format if requested + if ( + generation_config.return_legacy_cache is not False # Should check for `True` after v4.47 + and not is_torchdynamo_compiling() + and hasattr(result, "past_key_values") + and hasattr(result.past_key_values, "to_legacy_cache") + and result.past_key_values.to_legacy_cache is not None + ): + # handle BC (convert by default if the user hasn't passed a cache AND the cache is of the default type) + should_convert_cache = generation_config.return_legacy_cache + is_user_defined_cache = user_defined_cache is not None + is_default_cache_type = ( + type(result.past_key_values) == DynamicCache # noqa E721 + or ( + isinstance(result.past_key_values, EncoderDecoderCache) + and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 + and type(result.past_key_values.cross_attention_cache) == DynamicCache # noqa E721 + ) + ) + if not is_user_defined_cache and is_default_cache_type: + logger.warning_once( + "From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` " + 
"instance instead by default (as opposed to the legacy tuple of tuples format). If you want to " + "keep returning the legacy format, please set `return_legacy_cache=True`." + ) + should_convert_cache = True + if should_convert_cache: + result.past_key_values = result.past_key_values.to_legacy_cache() return result def _has_unfinished_sequences( diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index f6b8cd412be5..7dec1f26e1a3 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -768,7 +768,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.head = new_embeddings - def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, use_cache=None, **kwargs): # only last token for inputs_ids if the state is passed along. if state is not None: input_ids = input_ids[:, -1].unsqueeze(-1) @@ -780,6 +780,7 @@ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=Non model_inputs = {"input_ids": input_ids} model_inputs["state"] = state + model_inputs["use_cache"] = use_cache return model_inputs @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 72da44115f5c..ae52f6c67404 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -194,6 +194,7 @@ def _greedy_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -207,6 +208,7 @@ def _greedy_generate( output_scores=output_scores, output_logits=output_logits, 
return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **logits_processor_kwargs, **model_kwargs, ) @@ -224,6 +226,7 @@ def _sample_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): torch.manual_seed(0) logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True) @@ -239,6 +242,7 @@ def _sample_generate( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **logits_processor_kwargs, **model_kwargs, ) @@ -256,6 +260,7 @@ def _beam_search_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -268,6 +273,7 @@ def _beam_search_generate( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -286,6 +292,7 @@ def _beam_sample_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): torch.manual_seed(0) logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True) @@ -299,6 +306,7 @@ def _beam_sample_generate( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -317,6 +325,7 @@ def _group_beam_search_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else 
{} @@ -329,6 +338,7 @@ def _group_beam_search_generate( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -348,6 +358,7 @@ def _constrained_beam_search_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} @@ -361,6 +372,7 @@ def _constrained_beam_search_generate( output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, constraints=constraints, + use_cache=use_cache, **beam_kwargs, **logits_processor_kwargs, **model_kwargs, @@ -378,6 +390,7 @@ def _contrastive_generate( output_attentions=False, output_hidden_states=False, return_dict_in_generate=False, + use_cache=True, ): contrastive_search_kwargs = { "penalty_alpha": 0.6, @@ -396,6 +409,7 @@ def _contrastive_generate( output_scores=output_scores, output_logits=output_logits, return_dict_in_generate=return_dict_in_generate, + use_cache=use_cache, **logits_processor_kwargs, **model_kwargs, **contrastive_search_kwargs, @@ -419,7 +433,6 @@ def test_greedy_generate_dict_outputs(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = False model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( model=model, @@ -430,6 +443,7 @@ def test_greedy_generate_dict_outputs(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -454,7 +468,6 @@ def test_greedy_generate_dict_outputs_use_cache(self): if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): 
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate( @@ -466,6 +479,7 @@ def test_greedy_generate_dict_outputs_use_cache(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: @@ -495,7 +509,6 @@ def test_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = False model = model_class(config).to(torch_device).eval() output_generate = self._sample_generate( model=model, @@ -507,6 +520,7 @@ def test_sample_generate_dict_output(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -545,9 +559,6 @@ def test_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - # disable cache - config.use_cache = False - model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() output_generate = self._beam_search_generate( @@ -560,6 +571,7 @@ def test_beam_search_generate_dict_output(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) @@ -589,7 +601,6 @@ def test_beam_search_generate_dict_outputs_use_cache(self): model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() output_generate = self._beam_search_generate( @@ -602,6 
+613,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: @@ -676,9 +688,6 @@ def test_beam_sample_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - # disable cache - config.use_cache = False - model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() @@ -692,6 +701,7 @@ def test_beam_sample_generate_dict_output(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: @@ -764,7 +774,6 @@ def test_group_beam_search_generate(self): def test_group_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = False model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_diverse_beam_kwargs() @@ -778,6 +787,7 @@ def test_group_beam_search_generate_dict_output(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) @@ -857,9 +867,6 @@ def test_constrained_beam_search_generate_dict_output(self): for model_class in self.all_generative_model_classes: config, input_ids, attention_mask = self._get_input_ids_and_config() - # disable cache - config.use_cache = False - model = model_class(config).to(torch_device).eval() # Sample constraints @@ -882,6 +889,7 @@ def test_constrained_beam_search_generate_dict_output(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=False, ) if 
model.config.is_encoder_decoder: @@ -913,13 +921,12 @@ def test_contrastive_generate(self): # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True # test old generation output for backwards compatibility model = model_class(config).to(torch_device).eval() output_generate = self._contrastive_generate( - model=model, input_ids=input_ids, attention_mask=attention_mask + model=model, input_ids=input_ids, attention_mask=attention_mask, use_cache=True ) if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) @@ -940,7 +947,6 @@ def test_contrastive_generate_dict_outputs_use_cache(self): # NOTE: contrastive search only works with cache on at the moment. if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() @@ -953,6 +959,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self): output_hidden_states=True, output_attentions=self.has_attentions, return_dict_in_generate=True, + use_cache=True, ) if model.config.is_encoder_decoder: @@ -978,7 +985,6 @@ def test_contrastive_generate_low_memory(self): if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True # test output equality of low versus high memory @@ -991,6 +997,7 @@ def test_contrastive_generate_low_memory(self): low_memory=True, max_new_tokens=self.max_new_tokens, attention_mask=attention_mask, + use_cache=True, ) high_output = model.generate( @@ -1000,6 +1007,7 @@ def test_contrastive_generate_low_memory(self): low_memory=False, max_new_tokens=self.max_new_tokens, attention_mask=attention_mask, + use_cache=True, ) self.assertListEqual(low_output.tolist(), high_output.tolist()) 
@@ -1031,10 +1039,17 @@ def test_beam_search_low_memory(self): # test output equality of low versus high memory model = model_class(config).to(torch_device).eval() - low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True) + low_output = model.generate( + input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True, use_cache=True + ) high_output = model.generate( - input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False + input_ids, + max_new_tokens=8, + num_beams=5, + early_stopping=True, + low_memory=False, + use_cache=True, ) self.assertListEqual(low_output.tolist(), high_output.tolist()) @@ -1079,7 +1094,6 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1098,6 +1112,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": True, } output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1150,7 +1165,6 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1169,6 +1183,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": True, } output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1196,12 
+1211,6 @@ def test_dola_decoding_sample(self): # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm config, input_ids, attention_mask = self._get_input_ids_and_config() - # Some models don't support the cache and returning past_key_values - if not hasattr(config, "use_cache"): - config.use_cache = False - else: - config.use_cache = True - # Encoder-decoder models are not supported if config.is_encoder_decoder: self.skipTest("DoLa is not supported for encoder-decoder models") @@ -1224,11 +1233,12 @@ def test_dola_decoding_sample(self): "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": hasattr(config, "use_cache"), # Some models don't support the cache } generation_kwargs.update({"dola_layers": "low"}) model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {} output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs) - self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache) + self._check_outputs(output_dola, input_ids, model.config, use_cache=hasattr(config, "use_cache")) def test_assisted_decoding_sample(self): # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not @@ -1261,7 +1271,6 @@ def test_assisted_decoding_sample(self): if not hasattr(config, "use_cache"): self.skipTest(reason="This model doesn't support caching") - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() # Sets assisted generation arguments such that: @@ -1284,6 +1293,7 @@ def test_assisted_decoding_sample(self): "output_hidden_states": True, "output_attentions": self.has_attentions, "return_dict_in_generate": True, + "use_cache": True, } output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1566,7 +1576,6 @@ def test_generate_continue_from_past_key_values(self): # 3. 
ignore `token_type_ids` for simplicity # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is # active by default on some models - config.use_cache = True if "token_type_ids" in inputs: del inputs["token_type_ids"] @@ -1574,6 +1583,7 @@ def test_generate_continue_from_past_key_values(self): model.eval() model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1 model.generation_config.forced_eos_token_id = None + model.generation_config.use_cache = True # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) outputs = model(**inputs) @@ -1631,7 +1641,6 @@ def test_new_cache_format(self, num_beams, do_sample): self.skipTest(reason="This model does not support the new cache format") config, input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = True model = model_class(config).to(torch_device).eval() generation_kwargs = { @@ -1640,6 +1649,7 @@ def test_new_cache_format(self, num_beams, do_sample): "num_beams": num_beams, "num_return_sequences": num_beams, "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, } # Sets seed before calling `generate` for the case with do_sample=True @@ -1701,7 +1711,6 @@ def test_generate_with_static_cache(self): if config.is_encoder_decoder: self.skipTest(reason="This model is encoder-decoder and has Encoder-Decoder Cache") - config.use_cache = True config.is_decoder = True batch_size, seq_length = input_ids.shape max_new_tokens = 20 @@ -1712,6 +1721,7 @@ def test_generate_with_static_cache(self): "max_new_tokens": max_new_tokens, "cache_implementation": "static", "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, } max_cache_len = seq_length + max_new_tokens @@ -1740,7 +1750,6 @@ def test_generate_with_quant_cache(self): self.skipTest(reason="This model does not support the quantized cache format") config, 
input_ids, attention_mask = self._get_input_ids_and_config() - config.use_cache = True config.is_decoder = True model = model_class(config).to(torch_device).eval() @@ -1750,6 +1759,7 @@ def test_generate_with_quant_cache(self): # careful with group size, should be divisor of model's hidden size "cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128}, "return_dict_in_generate": True, # Required to return `past_key_values` + "use_cache": True, } results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) @@ -1890,22 +1900,24 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_ # Past Key Value States -- a few notes here: # 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1" - # 2. Some old models still return `output.past_key_values` even without `use_cache=True` - # 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is - # complete - models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba") + # 2. We ignore models that have unique cache structures (e.g. 
mamba) or are in need of refactoring to match the + # standard cache format (e.g. gptbigcode) + models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba", "xlnet") has_standard_cache = not any( model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache ) - if use_cache and has_standard_cache: - past_key_values = output.past_key_values - past_sequence_length = output.sequences.shape[-1] - 1 - self._check_past_key_values_for_generate( - num_sequences_in_output, - past_key_values, - seq_length=past_sequence_length, - config=config, - ) + if has_standard_cache: + if use_cache: + past_key_values = output.past_key_values + past_sequence_length = output.sequences.shape[-1] - 1 + self._check_past_key_values_for_generate( + num_sequences_in_output, + past_key_values, + seq_length=past_sequence_length, + config=config, + ) + elif use_cache is False: + self.assertTrue(output.past_key_values is None) def _check_scores(self, batch_size, scores, length, config): expected_shape = (batch_size, config.vocab_size)