
Add head_mask and decoder_head_mask to PyTorch LED#9856

Merged
stas00 merged 6 commits into huggingface:master from stancld:LED_encoder_decoder_head_masks
Feb 2, 2021
Conversation

@stancld (Contributor) commented Jan 27, 2021

This PR implements head_mask and decoder_head_mask for PyTorch LED (and for Longformer, as there is a copy dependency) and is a follow-up to the open issue #9814.

Motivation: This PR is part of an effort to enable the use of head_mask and decoder_head_mask for all encoder-decoder transformers, following the recent work on BART-like models (#9569).


Fixes: #9814

Reviewers: @patrickvonplaten @LysandreJik @stas00
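For context, a per-layer head mask simply rescales the attention probabilities of each head, with 0.0 disabling a head entirely. A minimal NumPy sketch of the idea (shapes and function name are illustrative, not the actual LED implementation):

```python
import numpy as np

def apply_layer_head_mask(attn_probs, layer_head_mask):
    """Rescale whole attention heads.

    attn_probs: (num_heads, tgt_len, src_len) attention probabilities
    layer_head_mask: (num_heads,) with 1.0 = keep head, 0.0 = mask head
    """
    # Broadcast the (num_heads,) mask over the two sequence dimensions.
    return layer_head_mask.reshape(-1, 1, 1) * attn_probs

probs = np.full((4, 3, 3), 1.0 / 3.0)   # uniform attention, 4 heads
mask = np.array([1.0, 0.0, 1.0, 1.0])   # disable head 1
masked = apply_layer_head_mask(probs, mask)
```

In the real models, `head_mask` (for the encoder) and `decoder_head_mask` are passed through the forward call and applied per layer in exactly this broadcast fashion.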

@stancld stancld changed the title Add head_mask and decoder_head_mask to LED (+ head_mask to Longformer due to copy dependency) Add head_mask and decoder_head_mask to PyTorch LED Jan 27, 2021
@patrickvonplaten (Contributor) left a comment


Great! Very clean implementation, thanks for taking care of this :-)

@LysandreJik (Member) left a comment

Fantastic! Thanks for working on that @stancld!

@sgugger (Collaborator) left a comment

Very clean, thanks for your PR! Just one styling nit, but feel free to ignore.

Comment on lines +838 to +840
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
Collaborator

Complete nit, but those asserts are not super well formatted. Can we replace them by if and raise a proper error?

Suggested change
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
if layer_head_mask.size() != (self.num_heads,):
    raise ValueError(
        f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
    )

(If you do one do all of them.)
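For illustration, the suggested pattern can be sketched as a hypothetical standalone check (NumPy's `.shape` stands in for torch's `.size()`; the function name is invented for this example):

```python
import numpy as np

def check_layer_head_mask(layer_head_mask, num_heads):
    # Raise a descriptive error instead of using a bare assert,
    # mirroring the if/raise pattern suggested in the review.
    if layer_head_mask.shape != (num_heads,):
        raise ValueError(
            f"Head mask for a single layer should be of size {(num_heads,)}, "
            f"but is {layer_head_mask.shape}"
        )

check_layer_head_mask(np.ones(4), 4)  # correct shape: no error
try:
    check_layer_head_mask(np.ones(3), 4)
except ValueError as e:
    caught = str(e)
```

Unlike `assert`, the explicit `raise ValueError` survives `python -O` (which strips assertions) and produces a message that states both the expected and the actual shape.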

Contributor

Much better. I tried to find a way to make the autoformatter behave nicely but couldn't figure it out. Yours is great, and it reads even better.

Contributor (Author)

@sgugger Thank you very much for your suggestion, I definitely agree with it. Do you think I should open a new PR to replace this pattern in all the models where it appears?

Contributor

Since it was a nit and not a requirement, I'd say whatever works the best for you, @stancld - thank you!

Contributor (Author)

@stas00 Thanks for the quick reply. In that case, I'll leave it as it is for now and replace these assertions in a new PR once head_mask and decoder_head_mask are implemented for all encoder-decoder models, so as not to mix that change into this PR, even though it's only a minor one; if that's okay :)

Contributor

That works well, @stancld!

Thank you for your great contribution!

@stas00 stas00 merged commit 71bdc07 into huggingface:master Feb 2, 2021
@stancld stancld deleted the LED_encoder_decoder_head_masks branch February 2, 2021 20:38


Development

Successfully merging this pull request may close these issues.

Missing head_mask and decoder_head_mask arguments in encoder-decoder models

5 participants