Generate: TF uses GenerationConfig as the basis for .generate() parametrization #20994
gante merged 8 commits into huggingface:main
```python
    if return_dict_in_generate is not None
    else self.generation_config.return_dict_in_generate
)
use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache)
```
This pattern, pulling `use_cache` out of `model_kwargs`, is the only new modification compared to the equivalent PR on the PT side.

It also fixes a previously undetected bug: when `use_cache` is set in `model_kwargs`, the XLA-compiled loop converts it to a `tf.bool` tensor, because it becomes part of the `tf.while_loop` input/output variables and is therefore assumed to be dynamic at runtime. Once converted to a `tf.bool`, XLA compilation fails on lines like `present_key_values = () if use_cache else None`, which require `use_cache` to be a static value.

Workaround: pull `use_cache` out of `model_kwargs` and pass the corresponding static variable into the `tf.while_loop`.
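A minimal plain-Python sketch of the workaround (function and variable names here are hypothetical, not the actual `TFGenerationMixin` internals): popping `use_cache` out of `model_kwargs` keeps it as a static Python bool, so it never joins the `tf.while_loop` carried state and conditionals like `() if use_cache else None` stay resolvable at trace time.

```python
# Stand-in for self.generation_config.use_cache
GENERATION_CONFIG_USE_CACHE = True

def prepare_generation_kwargs(model_kwargs):
    # Popping removes use_cache from the kwargs that flow through the
    # XLA-compiled loop, so it remains a static Python bool rather than
    # being traced into a dynamic tf.bool tensor.
    use_cache = model_kwargs.pop("use_cache", GENERATION_CONFIG_USE_CACHE)
    return use_cache, model_kwargs

use_cache, kwargs = prepare_generation_kwargs(
    {"use_cache": False, "attention_mask": None}
)
```

After this, `use_cache` can be closed over by the loop body and passed explicitly to `prepare_inputs_for_generation`, instead of riding along inside `model_kwargs`.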
```diff
 else:
     input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
-model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+model_inputs = self.prepare_inputs_for_generation(input_ids, use_cache=use_cache, **model_kwargs)
```
sgugger left a comment
Thanks a lot for expanding generation configs to TF!
What does this PR do?

Changes the TF side of `.generate()` such that it relies on the `GenerationConfig`. This is the TF equivalent of #20388.
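The precedence rule the PR applies can be sketched in plain Python (names are illustrative, not the real `TFGenerationMixin` code): an argument passed directly to `.generate()` wins, and the `GenerationConfig` supplies the default otherwise.

```python
class GenerationConfig:
    # Simplified stand-in for transformers.GenerationConfig, with only
    # the two flags used in this sketch.
    def __init__(self, return_dict_in_generate=False, use_cache=True):
        self.return_dict_in_generate = return_dict_in_generate
        self.use_cache = use_cache

def resolve_generate_flags(generation_config, return_dict_in_generate=None, **model_kwargs):
    # Ad-hoc .generate() arguments take precedence; the GenerationConfig
    # fills in anything left unspecified.
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else generation_config.return_dict_in_generate
    )
    use_cache = model_kwargs.pop("use_cache", generation_config.use_cache)
    return return_dict_in_generate, use_cache

# Caller overrides one flag; the other falls back to the config default.
flags = resolve_generate_flags(GenerationConfig(), return_dict_in_generate=True)
```

This mirrors the PT-side change: all generation defaults live in one `GenerationConfig` object instead of being scattered across the model config and hard-coded fallbacks.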