Add head_mask/decoder_head_mask for BART #9569

Merged
patrickvonplaten merged 7 commits into huggingface:master from stancld:head_mask_for_bart_new
Jan 18, 2021

Conversation

@stancld
Contributor

@stancld stancld commented Jan 13, 2021

This PR implements head_mask and decoder_head_mask for the PyTorch BART-based models listed below:

  • BART
  • MBart
  • Blenderbot
  • BlenderbotSmall
  • Marian
  • Pegasus

This PR is a follow-up to the closed PR #9404.

Motivation:
According to HuggingFace's website, "There is a growing field of study concerned with investigating the inner working of large-scale transformers like BERT (that some call “BERTology”)." This PR makes it possible to mask attention heads in the encoder and decoder exactly as for BERT, and thus creates an opportunity to study the importance of attention heads in encoder-decoder BERT-like models.

Description

  • New arguments head_mask and decoder_head_mask are passed to all the BART-based models (...Model, ...ForConditionalGeneration and ...ForQuestionAnswering) after the four arguments input_ids, attention_mask, decoder_input_ids and decoder_attention_mask, so that testing and the overall pipeline remain smooth.

  • This PR also contains an updated test_headmasking, which currently works except for one problem: BART-based models do not satisfy the condition

self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)

Fixing this problem is currently underway.
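For context, the masking mechanism itself is simple: inside each attention layer, the attention probabilities are multiplied by that layer's head mask, zeroing out the masked heads' contributions. A minimal torch-only sketch (illustrative shapes, following the (batch, heads, query, key) convention used by BART attention):

```python
import torch

batch, num_heads, seq = 2, 4, 8
attn_probs = torch.softmax(torch.randn(batch, num_heads, seq, seq), dim=-1)

layer_head_mask = torch.ones(num_heads)
layer_head_mask[0] = 0.0  # mask head 0 in this layer

# Broadcast the per-head mask over batch and sequence dimensions.
masked = layer_head_mask.view(1, -1, 1, 1) * attn_probs

assert masked[:, 0].abs().sum().item() == 0.0  # head 0 silenced
assert masked[:, 1].abs().sum().item() != 0.0  # other heads untouched
```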

Reviewer: @patrickvonplaten

This branch implements head_mask and decoder_head_mask
for BART-based models. Full list below:
- BART
- MBart
- Blenderbot
- BlenderbotSmall
- Marian
- Pegasus

Everything is accompanied by updated tests.
@patrickvonplaten
Contributor

Thanks for opening a new PR. Let me know if you need a review. (It's also OK if I go into the PR and fix some things if you're stuck :-) )

@stancld
Contributor Author

stancld commented Jan 13, 2021

@patrickvonplaten I hope this PR is ready for review again. The only thing left to resolve is the issue in test_headmasking described above. I've been trying to fix it, but I'd of course be grateful if you could have a look at it too :)

* Fix test_headmasking for BART-like models,
which have only 2 layers in each module.
The condition
```
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
is, therefore, invalid for encoder-decoder models considering
the `head_mask`
```
head_mask = torch.ones(
    self.model_tester.num_hidden_layers,
    self.model_tester.num_attention_heads,
    device=torch_device,
)
head_mask[0, 0] = 0
head_mask[-1, :-1] = 0
```
specified in the `test_headmasking` test/function.
@stancld
Contributor Author

stancld commented Jan 14, 2021

Hey @patrickvonplaten. I'd like to let you know I've fixed test_headmasking for the BART-based models. The problem was that the code in

self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)

pointed to the last layer of the encoder/decoder (encoder-decoder models have only 2 layers in each module during testing, while BERT has 5). As a result, this condition was invalid for BART-based models given the head_mask

head_mask = torch.ones(
    self.model_tester.num_hidden_layers,
    self.model_tester.num_attention_heads,
    device=torch_device,
)
head_mask[0, 0] = 0
head_mask[-1, :-1] = 0

I hope this PR is now ready for review.
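The failure mode described above can be sketched with a toy torch-only computation (the layer counts match the test configs mentioned above; the head and sequence sizes are illustrative assumptions):

```python
import torch

num_heads = 4

def apply_head_mask(num_layers):
    # Build the head_mask exactly as in test_headmasking.
    head_mask = torch.ones(num_layers, num_heads)
    head_mask[0, 0] = 0
    head_mask[-1, :-1] = 0
    # Stand-in attention probabilities: all ones, shape (batch, heads, q, k).
    attn = torch.ones(1, num_heads, 8, 8)
    # One masked attention tensor per layer.
    return [head_mask[layer].view(1, -1, 1, 1) * attn for layer in range(num_layers)]

# With 5 layers (BERT's test config), layer 1 is fully unmasked,
# so head 0 of attentions[1] has a nonzero sum and the check passes.
attentions = apply_head_mask(5)
assert attentions[1][..., 0, :, :].flatten().sum().item() != 0.0

# With 2 layers (BART's test config), layer 1 is also the *last* layer,
# where every head except the final one is masked, so the check fails.
attentions = apply_head_mask(2)
assert attentions[1][..., 0, :, :].flatten().sum().item() == 0.0
```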

Comment thread: tests/test_modeling_common.py
```diff
 is_encoder_decoder = True
 test_pruning = False
-test_head_masking = False
+test_head_masking = True
```
Contributor


nice!

```
if model.config.is_encoder_decoder:
    signature = inspect.signature(model.forward)
    arg_names = [*signature.parameters.keys()]
    if "decoder_head_mask" in arg_names:  # necessary differentiation because of T5 model
```
Contributor

@patrickvonplaten patrickvonplaten Jan 14, 2021


good for me for now - could you maybe open an issue saying that T5 should separate "head_mask" and "decoder_head_mask" and ping me on it? Then we can clean this up for T5 at a later stage :-)
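The signature check quoted above is easy to reproduce in isolation; a minimal sketch using only the standard library, with a toy `forward` standing in for `model.forward` (no transformers dependency):

```python
import inspect

# Hypothetical stand-in for a BART-style model.forward signature.
def forward(input_ids, attention_mask=None, head_mask=None, decoder_head_mask=None):
    pass

# Same pattern as in test_modeling_common.py: list the parameter names
# and branch on whether decoder_head_mask is accepted.
arg_names = [*inspect.signature(forward).parameters.keys()]
assert "decoder_head_mask" in arg_names
assert arg_names[0] == "input_ids"
```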

Contributor

@patrickvonplaten patrickvonplaten left a comment


This is awesome! Amazing work @stancld - very clean and nice comments. If you want it would be awesome if you could open an issue regarding T5 having only a "head_mask", but no "decoder_head_mask" and ping me on it so that we can fix this in a follow-up PR :-)

Otherwise, LGTM!

Collaborator

@sgugger sgugger left a comment


Looks great! Thanks for adding this!

Comment thread src/transformers/models/bart/modeling_bart.py Outdated
Comment thread src/transformers/models/blenderbot/modeling_blenderbot.py Outdated
Comment thread src/transformers/models/marian/modeling_marian.py Outdated
Comment thread src/transformers/models/mbart/modeling_mbart.py Outdated
Comment thread src/transformers/models/pegasus/modeling_pegasus.py Outdated
stancld added a commit to stancld/transformers that referenced this pull request Jan 16, 2021
* Add decoder_head_mask args into T5Model and T5ForConditionalGeneration

* Slightly change the order of input args to be in accordance
with the convention from BART-based models introduced within the PR huggingface#9569.
stancld added a commit to stancld/transformers that referenced this pull request Jan 16, 2021
* Separate head_mask and decoder_head_mask args in TF T5 models

* Slightly change the order of input args to follow convention
of BART-based models updated in PR huggingface#9569

* Update test_forward_signature tests/test_modeling_tf_common.py
w.r.t. the changed order of input args
stancld added a commit to stancld/transformers that referenced this pull request Jan 16, 2021
* Add head_mask and decoder_head_mask input arguments for TF BART-based
models as a TF counterpart to the PR huggingface#9569

* Add test_headmasking functionality to tests/test_modeling_tf_common.py

* TODO: Add a test to verify that we can get a gradient back for
importance score computation
Member

@LysandreJik LysandreJik left a comment


This is great, very clean implementation! Thanks for implementing the tests, too.

LGTM!

Comment thread tests/test_modeling_common.py Outdated
patrickvonplaten and others added 4 commits January 18, 2021 12:54
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@patrickvonplaten patrickvonplaten merged commit 357fb1c into huggingface:master Jan 18, 2021
patrickvonplaten pushed a commit that referenced this pull request Jan 19, 2021
* Add decoder_head_mask for PyTorch T5 model

* Add decoder_head_mask args into T5Model and T5ForConditionalGeneration

* Slightly change the order of input args to be in accordance
with the convention from BART-based models introduced within the PR #9569.

* Make style for modeling_t5.py

* Add decoder_head_mask for TF T5 models

* Separate head_mask and decoder_head_mask args in TF T5 models

* Slightly change the order of input args to follow convention
of BART-based models updated in PR #9569

* Update test_forward_signature tests/test_modeling_tf_common.py
w.r.t. the changed order of input args

* Add FutureWarnings for T5 and TFT5 models

* Add FutureWarnings for T5 and TFT5 models warning a user that
input argument `head_mask` was split into two arguments -
`head_mask` and `decoder_head_mask`

* Add default behaviour - `decoder_head_mask` is set to copy
`head_mask`

* Fix T5 modeling and FutureWarning

* Make proper usage of head_mask and decoder_head_mask
in cross_attention

* Fix conditions for raising FutureWarning

* Reformat FutureWarning in T5 modeling

* Refactor the warning message
LysandreJik pushed a commit that referenced this pull request Jan 26, 2021
* Add head_mask/decoder_head_mask for TF BART models

* Add head_mask and decoder_head_mask input arguments for TF BART-based
models as a TF counterpart to the PR #9569

* Add test_headmasking functionality to tests/test_modeling_tf_common.py

* TODO: Add a test to verify that we can get a gradient back for
importance score computation

* Remove redundant #TODO note

Remove redundant #TODO note from tests/test_modeling_tf_common.py

* Fix assertions

* Make style

* Fix ...Model input args and adjust one new test

* Add back head_mask and decoder_head_mask to BART-based ...Model
after the last commit

* Remove head_mask and decoder_head_mask from input_dict
in TF test_train_pipeline_custom_model, as these two have different
shapes than the other input args (necessary for passing this test)

* Revert adding global_rng in test_modeling_tf_common.py
@stancld stancld deleted the head_mask_for_bart_new branch January 26, 2021 16:35
