109 changes: 58 additions & 51 deletions src/transformers/generation/configuration_utils.py
@@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin):
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
penalty_alpha (`float`, *optional*):
The value balances the model confidence and the degeneration penalty in contrastive search decoding.
dola_layers (`str` or `List[int]`, *optional*):
@gante (Contributor, Author) commented on Aug 17, 2024:
moved up to this documentation section (Parameters that control the generation strategy used), which makes more sense

The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
be one of "low" or "high", which means using the lower or higher part of the model layers, respectively.
"low" means the lower half of the layers (up to the first 20 layers), and "high" means the upper half of the
layers (up to the last 20 layers).
If a list of integers, it must contain the indices of the layers to use as candidate premature layers in DoLa.
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
or [the paper](https://arxiv.org/abs/2309.03883) for more details.
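For context (not part of this diff): a minimal usage sketch of the `dola_layers` flag described above, passed straight through `generate()`. The checkpoint, prompt, and generation settings below are illustrative assumptions, not values from this PR.

```python
# Hypothetical usage sketch for `dola_layers`; the checkpoint is only an example.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("What is the capital of France?", return_tensors="pt")
# "low" targets the lower half of the model layers; a list of layer indices
# (where layer 0 is the word embedding layer) can be passed instead.
outputs = model.generate(
    **inputs, dola_layers="low", repetition_penalty=1.2, max_new_tokens=32
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```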

> Parameters that control the cache
@gante (Contributor, Author) commented:
new cache-related docs section in GenerationConfig, moved all cache-related flags here


use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
cache_implementation (`str`, *optional*, defaults to `None`):
Cache class that should be used when generating.
cache_config (`CacheConfig` or `dict`, *optional*, defaults to `None`):
Arguments used in the key-value cache class can be passed in `cache_config`. They can be passed as a `dict`,
which will be converted to its respective `CacheConfig` class internally, or as a `CacheConfig` instance
matching the indicated `cache_implementation`.
return_legacy_cache (`bool`, *optional*, defaults to `True`):
Whether to return the legacy or the new format of the cache when `DynamicCache` is used by default.
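To make the new cache section above concrete, a minimal sketch of how these flags can be combined on a `GenerationConfig` (the specific values are examples, not defaults introduced by this PR):

```python
# Illustrative sketch of the cache-related flags documented above.
from transformers import GenerationConfig

# Default dynamic cache, returned in the legacy (tuple of tuples) format.
legacy_cfg = GenerationConfig(use_cache=True, return_legacy_cache=True)

# A named cache implementation; "static" is one example of a supported value.
static_cfg = GenerationConfig(use_cache=True, cache_implementation="static")
```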

> Parameters for manipulation of the model output logits

@@ -307,29 +327,6 @@ class GenerationConfig(PushToHubMixin):
max_matching_ngram_size (`int`, *optional*, default to `None`):
The maximum ngram size to be considered for matching in the prompt. Default to 2 if not provided.

> Generation parameters exclusive to [DoLa decoding](https://arxiv.org/abs/2309.03883)

dola_layers (`str` or `List[int]`, *optional*):
The layers to use for DoLa decoding. If `None`, DoLa decoding is not used. If a string, it must
be one of "low" or "high", which means using the lower part or higher part of the model layers, respectively.
"low" means the first half of the layers up to the first 20 layers, and "high" means the last half of the
layers up to the last 20 layers.
If a list of integers, it must contain the indices of the layers to use for candidate premature layers in DoLa.
The 0-th layer is the word embedding layer of the model. Set to `'low'` to improve long-answer reasoning tasks,
`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
or [the paper](https://arxiv.org/abs/2309.03883) for more details.

> Parameters specific to the caching mechanism:

cache_implementation (`str`, *optional*, default to `None`):
Cache class that should be used when generating.
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its repsective `CacheConfig` internally.
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
return_legacy_cache (`bool`, *optional*, default to `True`):
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.

> Wild card

generation_kwargs:
@@ -352,7 +349,19 @@ def __init__(self, **kwargs):
self.num_beams = kwargs.pop("num_beams", 1)
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
self.dola_layers = kwargs.pop("dola_layers", None)

# Parameters that control the cache
self.use_cache = kwargs.pop("use_cache", True)
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
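The block above normalizes a plain `dict` into the matching cache-config class when the chosen implementation requires one. A sketch of the observable behavior, assuming `"quantized"` is a key of `NEEDS_CACHE_CONFIG` mapping to `QuantizedCacheConfig` (as in the transformers version this PR targets):

```python
# Sketch of the dict -> CacheConfig normalization performed in __init__ above.
from transformers import GenerationConfig
from transformers.cache_utils import QuantizedCacheConfig

config = GenerationConfig(
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},  # plain dict on input
)
# After __init__, the dict has been converted into the matching config class.
assert isinstance(config.cache_config, QuantizedCacheConfig)
```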

# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)
@@ -411,20 +420,6 @@ def __init__(self, **kwargs):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")

# DoLa generation
self.dola_layers = kwargs.pop("dola_layers", None)

# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None and self.cache_implementation in NEEDS_CACHE_CONFIG:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", True)

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
@@ -544,8 +539,9 @@ def validate(self, is_init=False):
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
if self.pad_token_id is not None and self.pad_token_id < 0:
warnings.warn(
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. "
"Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values."
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
@gante (Contributor, Author) commented: (>120 chars/line)

"generating, if there is padding. Please set `pad_token_id` explicitly as "
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
)
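As the reworded warning suggests, the usual fix on the user side is to set a valid, non-negative pad token id before batched generation. A sketch (the checkpoint is an arbitrary example; GPT-2 has no dedicated pad token, so EOS is reused):

```python
# Sketch of the fix suggested by the warning above.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint only
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Reuse EOS as the pad token so batched generation has a valid, non-negative id.
model.generation_config.pad_token_id = (
    tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
)
```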

# Validation of attribute relations:
@@ -675,6 +671,14 @@ def validate(self, is_init=False):
group_error_prefix
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
)
# DoLa generation
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
@gante (Contributor, Author) commented: (moved)

warnings.warn(
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
"DoLa decoding is `repetition_penalty>=1.2`.",
UserWarning,
)
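A quick sketch of when this relocated check fires, following the recommendation in the message above:

```python
# Sketch: the DoLa check above warns when repetition_penalty is below 1.2.
from transformers import GenerationConfig

GenerationConfig(dola_layers="low")                          # repetition_penalty defaults to 1.0 -> UserWarning
GenerationConfig(dola_layers="low", repetition_penalty=1.2)  # recommended value, no warning
```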

# 4. check `num_return_sequences`
if self.num_return_sequences != 1:
@@ -690,7 +694,7 @@ def validate(self, is_init=False):
f"({self.num_beams})."
)

# 5. check `cache_config`
# 5. check cache-related arguments
if self.cache_config is not None:
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
if cache_class is None:
@@ -702,6 +706,20 @@ def validate(self, is_init=False):
if not isinstance(self.cache_config, cache_class):
self.cache_config = cache_class.from_dict(self.cache_config)
self.cache_config.validate()
if self.use_cache is False:
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
# (otherwise a user might need to overwrite several parameters).
no_cache_warning = (
"You have set `use_cache` to `False`, but {cache_arg} is set to {cache_arg_value}. {cache_arg} will "
"have no effect."
)
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
if getattr(self, arg_name) is not None:
logger.warning_once(
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)),
UserWarning,
)
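A sketch of the behavior this new block adds: contradictory cache flags are logged and ignored rather than rejected, as the comment above explains, so `use_cache=False` stays usable as a quick hot-fix:

```python
# Sketch: with use_cache=False, cache-related flags only trigger a logged warning.
from transformers import GenerationConfig

# Logs something along the lines of:
# "You have set `use_cache` to `False`, but cache_implementation is set to static. ..."
config = GenerationConfig(use_cache=False, cache_implementation="static")

# The contradictory flag is kept on the config but has no effect at generation time.
print(config.cache_implementation)  # "static"
```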

# 6. check watermarking arguments
if self.watermarking_config is not None:
@@ -727,17 +745,6 @@ def validate(self, is_init=False):
"`generate()` (or a pipeline) directly."
)

# 6. if dola_layers is set, check if repetition_penalty is set to >= 1.2
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
dola_decoding_wrong_parameter_msg = (
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of {repetition_penalty}, "
"which could induce unwanted repetition. The recommended value for DoLa decoding is `repetition_penalty>=1.2`."
)
warnings.warn(
dola_decoding_wrong_parameter_msg.format(repetition_penalty=self.repetition_penalty),
UserWarning,
)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],