Generate: Deprecate returning legacy cache by default; Handle use_cache=False
#32863
Changes from all commits: 13c07d0, a6611e7, 69bf5f4, d3c3e5a, bc5b50a, 8958d49, 914a7ea, e8492bf, 550f7a6
```diff
@@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin):
             [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
         penalty_alpha (`float`, *optional*):
             The values balance the model confidence and the degeneration penalty in contrastive search decoding.
+        dola_layers (`str` or `List[int]`, *optional*):
+            The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
+            be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
+            "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
+            layers up to the last 20 layers.
+            If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
+            The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
+            `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
+            or [the paper](https://arxiv.org/abs/2309.03883) for more details.
+
+    > Parameters that control the cache
+
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should use the past last key/values attentions (if applicable to the model) to
             speed up decoding.
+        cache_implementation (`str`, *optional*, defaults to `None`):
+            Cache class that should be used when generating.
+        cache_config (`CacheConfig` or `dict`, *optional*, defaults to `None`):
+            Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
+            it will be converted to its respective `CacheConfig` internally.
+            Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
+        return_legacy_cache (`bool`, *optional*, defaults to `True`):
+            Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.

    > Parameters for manipulation of the model output logits
```

**Contributor (author):** new cache-related docs section

**Contributor (author):** `dola_layers` moved up to this documentation section ("Parameters that control the generation strategy used"), which makes more sense
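To make the documented arguments concrete, here is a minimal sketch of configuring the cache via `GenerationConfig`. It assumes `"quantized"` is a valid `cache_implementation` whose config class is `QuantizedCacheConfig` (with `backend` and `nbits` fields), which matches the source file at the time of this PR but may differ across versions:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(
    use_cache=True,                    # enable the KV cache (the default)
    cache_implementation="quantized",  # select a non-default cache class (assumed valid here)
    cache_config={"backend": "quanto", "nbits": 4},  # dict, converted to a CacheConfig internally
    return_legacy_cache=False,         # opt into the new `Cache` return format
)
```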
```diff
@@ -307,29 +327,6 @@ class GenerationConfig(PushToHubMixin):
         max_matching_ngram_size (`int`, *optional*, default to `None`):
             The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.

-    > Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)
-
-        dola_layers (`str` or `List[int]`, *optional*):
-            The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
-            be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
-            "low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
-            layers up to the last 20 layers.
-            If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
-            The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
-            `'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
-            or [the paper](https://arxiv.org/abs/2309.03883) for more details.
-
-    > Parameters specific to the caching mechanism:
-
-        cache_implementation (`str`, *optional*, default to `None`):
-            Cache class that should be used when generating.
-        cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
-            Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
-            it will be converted to its respective `CacheConfig` internally.
-            Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
-        return_legacy_cache (`bool`, *optional*, default to `True`):
-            Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.

    > Wild card

        generation_kwargs:
```
```diff
@@ -352,7 +349,19 @@ def __init__(self, **kwargs):
         self.num_beams = kwargs.pop("num_beams", 1)
         self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
         self.penalty_alpha = kwargs.pop("penalty_alpha", None)
+        self.dola_layers = kwargs.pop("dola_layers", None)
+
+        # Parameters that control the cache
+        self.use_cache = kwargs.pop("use_cache", True)
+        self.cache_implementation = kwargs.pop("cache_implementation", None)
+        self.cache_config = kwargs.pop("cache_config", None)
+        if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
+            cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
+            if self.cache_config is None:
+                self.cache_config = cache_config_class()
+            elif isinstance(self.cache_config, dict):
+                self.cache_config = cache_config_class.from_dict(self.cache_config)
+        self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)

         # Parameters for manipulation of the model output logits
         self.temperature = kwargs.pop("temperature", 1.0)
```
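As a sketch of what this constructor block does: a dict-valued `cache_config` is converted into the `CacheConfig` subclass registered in `NEEDS_CACHE_CONFIG` (assumed here to map `"quantized"` to `QuantizedCacheConfig`, as in the source file), and `return_legacy_cache` stays `None` ("unset") when not provided:

```python
from transformers import GenerationConfig

# "quantized" and the config class name are assumptions about NEEDS_CACHE_CONFIG.
config = GenerationConfig(cache_implementation="quantized", cache_config={"nbits": 4})
print(type(config.cache_config).__name__)  # expected: QuantizedCacheConfig
print(config.return_legacy_cache)          # None -> unset, enabling the deprecation path
```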
```diff
@@ -411,20 +420,6 @@ def __init__(self, **kwargs):
         self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
         self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

-        # DoLa generation
-        self.dola_layers = kwargs.pop("dola_layers", None)
-
-        # Cache implementation
-        self.cache_implementation = kwargs.pop("cache_implementation", None)
-        self.cache_config = kwargs.pop("cache_config", None)
-        if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
-            cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
-            if self.cache_config is None:
-                self.cache_config = cache_config_class()
-            elif isinstance(self.cache_config, dict):
-                self.cache_config = cache_config_class.from_dict(self.cache_config)
-        self.return_legacy_cache = kwargs.pop("return_legacy_cache", True)

         # Prompt lookup decoding
         self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
         self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
```
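The crux of the deprecation is visible across these two hunks: `return_legacy_cache` now defaults to `None` ("unset") instead of `True`, so `generate()` can tell whether the legacy tuple format was requested explicitly. A hedged sketch of opting into the new format (the checkpoint name is illustrative, not part of the PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=5,
    return_dict_in_generate=True,
    return_legacy_cache=False,  # explicit opt-in to the new format
)
# With return_legacy_cache=False, past_key_values should be a Cache instance
# (e.g. DynamicCache) rather than the legacy tuple-of-tuples.
print(type(outputs.past_key_values))
```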
```diff
@@ -544,8 +539,9 @@ def validate(self, is_init=False):
             raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
         if self.pad_token_id is not None and self.pad_token_id < 0:
             warnings.warn(
-                f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. "
-                "Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values."
+                f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
+                "generating, if there is padding. Please set `pad_token_id` explicitly as "
+                "`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
             )

        # Validation of attribute relations:
```

**Contributor (author):** line wrapping only (>120 chars/line)
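Continuing the sketch above, the fix the reworded warning asks for is a one-liner (falling back to the EOS token when the tokenizer defines no pad token):

```python
model.generation_config.pad_token_id = (
    tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
)
```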
```diff
@@ -675,6 +671,14 @@ def validate(self, is_init=False):
                     group_error_prefix
                     + "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
                 )
+        # DoLa generation
+        if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
+            warnings.warn(
+                "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
+                f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
+                "DoLa decoding is `repetition_penalty>=1.2`.",
+                UserWarning,
+            )

         # 4. check `num_return_sequences`
         if self.num_return_sequences != 1:
```

**Contributor (author):** (moved)
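The relocated check encodes the DoLa paper's recommendation. Continuing the illustrative setup from above, a usage that keeps the new warning silent:

```python
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    dola_layers="high",      # contrast against the upper half of the layers
    repetition_penalty=1.2,  # >= 1.2, per the validation warning above
)
```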
```diff
@@ -690,7 +694,7 @@ def validate(self, is_init=False):
                 f"({self.num_beams})."
             )

-        # 5. check `cache_config`
+        # 5. check cache-related arguments
         if self.cache_config is not None:
             cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
             if cache_class is None:
```
```diff
@@ -702,6 +706,20 @@ def validate(self, is_init=False):
             if not isinstance(self.cache_config, cache_class):
                 self.cache_config = cache_class.from_dict(self.cache_config)
             self.cache_config.validate()
+        if self.use_cache is False:
+            # In this case, all cache-related arguments should be unset. However, since `use_cache=False` is
+            # often passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an
+            # error (otherwise a user might need to overwrite several parameters).
+            no_cache_warning = (
+                "You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will "
+                "have no effect."
+            )
+            for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
+                if getattr(self, arg_name) is not None:
+                    logger.warning_once(
+                        no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)),
+                        UserWarning,
+                    )

         # 6. check watermarking arguments
         if self.watermarking_config is not None:
```
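A quick sketch of the new warning path: combining `use_cache=False` with a cache argument now logs a "will have no effect" warning instead of raising (`"static"` is assumed to be an accepted `cache_implementation` string here):

```python
from transformers import GenerationConfig

# The warning fires during validation (which __init__ also runs), roughly:
# "You have set `use_cache` to `False`, but cache_implementation is set to
#  static. cache_implementation will have no effect."
config = GenerationConfig(use_cache=False, cache_implementation="static")
config.validate()  # warning_once: repeated calls do not re-log
```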
```diff
@@ -727,17 +745,6 @@ def validate(self, is_init=False):
                 "`generate()` (or a pipeline) directly."
             )

-        # 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
-        if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
-            dola_decoding_wrong_parameter_msg = (
-                "`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
-                "which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
-            )
-            warnings.warn(
-                dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
-                UserWarning,
-            )

     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
```