Add head_mask/decoder_head_mask for BART #9569
patrickvonplaten merged 7 commits into huggingface:master
Conversation
This branch implements `head_mask` and `decoder_head_mask` for BART-based models. Full list below:
- BART
- MBart
- Blenderbot
- BlenderbotSmall
- Marian
- Pegasus

Everything is accompanied by updated testing.
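To make the new arguments concrete, here is a hedged sketch of how such masks are intended to be used. The model call is shown only as a comment (a real call needs `transformers` and `torch` installed), and `make_head_mask` is a hypothetical plain-Python helper, not code from this PR. A mask entry of 1.0 keeps a head; 0.0 disables it.

```python
# Intended usage (schematic; not executed here):
#
#   outputs = model(
#       input_ids,
#       attention_mask=attention_mask,
#       decoder_input_ids=decoder_input_ids,
#       decoder_attention_mask=decoder_attention_mask,
#       head_mask=encoder_mask,          # shape (encoder_layers, num_heads)
#       decoder_head_mask=decoder_mask,  # shape (decoder_layers, num_heads)
#   )

def make_head_mask(num_layers, num_heads, disabled=()):
    """Build a num_layers x num_heads mask; `disabled` lists (layer, head) pairs."""
    mask = [[1.0] * num_heads for _ in range(num_layers)]
    for layer, head in disabled:
        mask[layer][head] = 0.0
    return mask

# Disable head 0 of the first encoder layer, keep everything else.
encoder_mask = make_head_mask(num_layers=6, num_heads=8, disabled=[(0, 0)])
print(encoder_mask[0][:2])  # [0.0, 1.0]
```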
Thanks for opening a new PR. Let me know if you need a review. (It's also OK if I go into the PR and fix some things if you're stuck :-) )
@patrickvonplaten I hope this PR is again ready for review. The only thing remaining to resolve is the issue in `test_headmasking`.
* Fix test_headmasking for BART-like models, which have only 2 layers in each module.
The condition
```
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
is, therefore, invalid for encoder-decoder models considering
the `head_mask`
```
head_mask = torch.ones(
self.model_tester.num_hidden_layers,
self.model_tester.num_attention_heads,
device=torch_device,
)
head_mask[0, 0] = 0
head_mask[-1, :-1] = 0
```
specified in the `test_headmasking` test.
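To make the failure concrete, here is a plain-Python stand-in for the `head_mask` above (`build_test_head_mask` is a hypothetical helper for illustration, not code from the PR), showing why the last-layer assertion only breaks when a module has 2 layers:

```python
def build_test_head_mask(num_layers, num_heads):
    """Plain-Python version of the head_mask built in test_headmasking."""
    mask = [[1.0] * num_heads for _ in range(num_layers)]
    mask[0][0] = 0.0                   # head_mask[0, 0] = 0
    for head in range(num_heads - 1):  # head_mask[-1, :-1] = 0
        mask[-1][head] = 0.0
    return mask

# BERT-style test setup: 5 layers, so layer 1 is untouched and head 0 of
# attentions[1] can stay non-zero -- the assertNotEqual passes.
assert build_test_head_mask(5, 4)[1][0] == 1.0

# BART encoder/decoder in the test: 2 layers, so layer 1 is also the LAST
# layer, where head 0 is masked -- the assertNotEqual must fail.
assert build_test_head_mask(2, 4)[1][0] == 0.0
```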
Hey @patrickvonplaten. I would like to inform you that I fixed the assertion that pointed to the last layer of the encoder/decoder (encoder-decoder models have only 2 layers in each module, while BERT has 5 layers during testing). At the end of the day, this condition was invalid for BART-based models considering the `head_mask`. I hope this PR is now ready for review.
```
  is_encoder_decoder = True
  test_pruning = False
- test_head_masking = False
+ test_head_masking = True
```
```
if model.config.is_encoder_decoder:
    signature = inspect.signature(model.forward)
    arg_names = [*signature.parameters.keys()]
    if "decoder_head_mask" in arg_names:  # necessary differentiation because of T5 model
```
good for me for now - could you maybe open an issue saying that T5 should separate "head_mask" and "decoder_head_mask" and ping me on it? Then we can clean this up for T5 at a later stage :-)
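The signature check above can be sketched in isolation. The two forward functions below are hypothetical stand-ins for the BART-like and T5-like signatures discussed here, not the real model code:

```python
import inspect

def bart_like_forward(input_ids, attention_mask=None,
                      head_mask=None, decoder_head_mask=None):
    ...

def t5_like_forward(input_ids, attention_mask=None, head_mask=None):
    ...

def accepts_decoder_head_mask(forward_fn):
    # Same pattern as in the test: read the parameter names off the signature.
    arg_names = [*inspect.signature(forward_fn).parameters.keys()]
    return "decoder_head_mask" in arg_names

print(accepts_decoder_head_mask(bart_like_forward))  # True
print(accepts_decoder_head_mask(t5_like_forward))    # False
```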
patrickvonplaten left a comment
This is awesome! Amazing work @stancld - very clean and nice comments. If you want it would be awesome if you could open an issue regarding T5 having only a "head_mask", but no "decoder_head_mask" and ping me on it so that we can fix this in a follow-up PR :-)
Otherwise, LGTM!
sgugger left a comment
Looks great! Thanks for adding this!
LysandreJik left a comment
This is great, very clean implementation! Thanks for implementing the tests, too.
LGTM!
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Add decoder_head_mask for PyTorch T5 model
* Add decoder_head_mask args into T5Model and T5ForConditionalGeneration
* Slightly change the order of input args to be in accordance with the convention from BART-based models introduced within the PR #9569
* Make style for modeling_t5.py
* Add decoder_head_mask for TF T5 models
* Separate head_mask and decoder_head_mask args in TF T5 models
* Slightly change the order of input args to follow the convention of BART-based models updated in PR #9569
* Update test_forward_signature in tests/test_modeling_tf_common.py w.r.t. the changed order of input args
* Add FutureWarnings for T5 and TFT5 models, warning a user that the input argument `head_mask` was split into two arguments: `head_mask` and `decoder_head_mask`
* Add default behaviour: `decoder_head_mask` is set to copy `head_mask`
* Fix T5 modeling and FutureWarning
* Make proper usage of head_mask and decoder_head_mask in cross_attention
* Fix conditions for raising FutureWarning
* Reformat FutureWarning in T5 modeling
* Refactor the warning message
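The default behaviour described above ("`decoder_head_mask` is set to copy `head_mask`", with a `FutureWarning`) can be sketched as follows. `resolve_head_masks` is a hypothetical helper for illustration; the actual logic lives inside the T5 modeling code:

```python
import warnings

def resolve_head_masks(head_mask=None, decoder_head_mask=None):
    # If only head_mask is given, warn and reuse it for the decoder.
    if head_mask is not None and decoder_head_mask is None:
        warnings.warn(
            "`head_mask` was split into `head_mask` and `decoder_head_mask`; "
            "defaulting `decoder_head_mask` to `head_mask`.",
            FutureWarning,
        )
        decoder_head_mask = head_mask
    return head_mask, decoder_head_mask

head_mask, decoder_head_mask = resolve_head_masks(head_mask=[1.0, 0.0])
print(decoder_head_mask)  # [1.0, 0.0]
```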
* Add head_mask/decoder_head_mask for TF BART models
* Add head_mask and decoder_head_mask input arguments for TF BART-based models as a TF counterpart to the PR #9569
* Add test_headmasking functionality to tests/test_modeling_tf_common.py
* TODO: Add a test to verify that we can get a gradient back for importance score computation
* Remove redundant #TODO note from tests/test_modeling_tf_common.py
* Fix assertions
* Make style
* Fix ...Model input args and adjust one new test
* Add back head_mask and decoder_head_mask to BART-based ...Model after the last commit
* Remove head_mask and decoder_head_mask from input_dict in TF test_train_pipeline_custom_model as these two have a different shape than other input args (necessary for passing this test)
* Revert adding global_rng in test_modeling_tf_common.py
This PR implements `head_mask` and `decoder_head_mask` for PyTorch BART-based models. For the full list, please see below:
- BART
- MBart
- Blenderbot
- BlenderbotSmall
- Marian
- Pegasus

This PR is a follow-up to the closed PR #9404.
Motivation:
According to HuggingFace's website, "There is a growing field of study concerned with investigating the inner working of large-scale transformers like BERT (that some call “BERTology”)." This PR makes it possible to mask attention heads in encoder and decoder modules exactly as for BERT, and thus creates an opportunity to study the importance of attention heads in encoder-decoder BERT-like models.
Description
New arguments `head_mask` and `decoder_head_mask` are passed to all the BART-based `...Model`, `...ForConditionalGeneration`, and `...ForQuestionAnswering` models after the four arguments `input_ids`, `attention_mask`, `decoder_input_ids`, `decoder_attention_mask`, so that testing and the whole pipeline remain smooth.

This PR also contains an updated `test_headmasking`, which currently works fine with one problem: BART-based models do not satisfy the assertion discussed above. Fixing this problem is currently underway.
Reviewer: @patrickvonplaten