
Generate: TF uses GenerationConfig as the basis for .generate() parametrization#20994

Merged
gante merged 8 commits into huggingface:main from gante:generation_config_tf on Jan 4, 2023

Conversation

gante (Contributor) commented Jan 4, 2023

What does this PR do?

Changes the TF side of .generate() such that it relies on the GenerationConfig. This is the TF equivalent of #20388
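To illustrate the parametrization model this PR adopts, here is a minimal, hypothetical sketch (plain Python, not the actual transformers code): the GenerationConfig supplies the defaults, and any kwargs passed explicitly to .generate() override them. The helper name resolve_generation_params is invented for illustration.

```python
# Hypothetical sketch: GenerationConfig values act as defaults,
# call-time kwargs (when not None) take precedence over them.
def resolve_generation_params(generation_config: dict, **generate_kwargs):
    params = dict(generation_config)  # start from the config defaults
    # only explicitly passed (non-None) kwargs override the config
    params.update({k: v for k, v in generate_kwargs.items() if v is not None})
    return params

config = {"max_new_tokens": 20, "do_sample": False}
params = resolve_generation_params(config, do_sample=True)
# params: {"max_new_tokens": 20, "do_sample": True}
```

The original config dict is left untouched, so the same GenerationConfig can be reused across multiple .generate() calls with different overrides.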

@gante gante changed the title Generate: TF uses generation config Generate: TF uses GenerationConfig as the basis for .generate() parametrization Jan 4, 2023
HuggingFaceDocBuilderDev commented Jan 4, 2023

The documentation is not available anymore as the PR was closed or merged.

@gante gante marked this pull request as ready for review January 4, 2023 17:31
@gante gante requested a review from sgugger January 4, 2023 17:31
      if return_dict_in_generate is not None
      else self.generation_config.return_dict_in_generate
  )
+ use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache)
gante (Contributor, Author) commented:
This pattern, pulling use_cache out of model_kwargs, is the only new modification compared to the equivalent PR on the PT side.

It also fixes an existing bug that had gone undetected: when use_cache is set in model_kwargs, the XLA-compiled loop converts it to a tf.bool tensor (because it is part of an input/output variable of the tf.while_loop, and therefore assumed to be dynamic at runtime). Once it becomes a tf.bool, XLA model compilation fails on lines like present_key_values = () if use_cache else None, which require use_cache to be a static variable.

Workaround: pull use_cache out of model_kwargs and use the resulting static variable in the tf.while_loop.
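The workaround can be sketched in plain Python (hypothetical helper name; the real logic lives inside the TF generation code and involves actual tf.while_loop tracing):

```python
def pop_use_cache(model_kwargs, config_use_cache=True):
    """Pull `use_cache` out of model_kwargs so it stays a static Python bool.

    If it were left inside model_kwargs, it would be threaded through the
    tf.while_loop input/output state and traced as a dynamic tf.bool tensor;
    trace-time control flow such as `() if use_cache else None` would then
    fail under XLA compilation.
    """
    use_cache = model_kwargs.pop("use_cache", config_use_cache)
    # With a static bool, this branch is resolved at trace time:
    present_key_values = () if use_cache else None
    return use_cache, present_key_values

kwargs = {"use_cache": False, "attention_mask": None}
use_cache, pkv = pop_use_cache(kwargs)
# use_cache is False, pkv is None, and "use_cache" is gone from kwargs
```

Because the value is popped before the loop is traced, the downstream prepare_inputs_for_generation call receives it as an explicit keyword argument rather than finding it inside model_kwargs.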

  else:
      input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+ model_inputs = self.prepare_inputs_for_generation(input_ids, use_cache=use_cache, **model_kwargs)
gante (Contributor, Author) commented:

(see comment above)

sgugger (Collaborator) approved these changes:

Thanks a lot for expanding generation configs to TF!

@gante gante merged commit a6c850e into huggingface:main Jan 4, 2023
@gante gante deleted the generation_config_tf branch January 4, 2023 18:23
silverriver pushed a commit to silverriver/transformers that referenced this pull request Jan 6, 2023