Add head_mask and decoder_head_mask to PyTorch LED #9856
stas00 merged 6 commits into huggingface:master
Conversation
* Add head_mask to longformer to fix dependencies of LED on Longformer.
* Not working yet
patrickvonplaten
left a comment
Great! Very clean implementation, thanks for taking care of this :-)
LysandreJik
left a comment
Fantastic! Thanks for working on that @stancld!
sgugger
left a comment
Very clean, thanks for your PR! Just one styling nit, but feel free to ignore.
```python
assert layer_head_mask.size() == (
    self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
```
Complete nit, but those asserts are not super well formatted. Can we replace them with an `if` and raise a proper error?
```diff
- assert layer_head_mask.size() == (
-     self.num_heads,
- ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+ if layer_head_mask.size() != (self.num_heads,):
+     raise ValueError(
+         f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
+     )
```
(If you do one do all of them.)
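The suggested pattern can be sketched as a small stdlib-only helper; this is an illustrative sketch, not the actual transformers code (the function name and the plain shape tuple standing in for `layer_head_mask.size()` are assumptions):

```python
def validate_head_mask_shape(mask_shape, num_heads):
    """Raise a proper error instead of asserting.

    `mask_shape` stands in for `layer_head_mask.size()` in the real code;
    a per-layer head mask must be a 1-D tensor of shape (num_heads,).
    """
    if tuple(mask_shape) != (num_heads,):
        raise ValueError(
            f"Head mask for a single layer should be of size {(num_heads,)}, "
            f"but is {tuple(mask_shape)}"
        )
```

Unlike `assert`, this check survives running Python with `-O` (which strips assertions) and raises an exception type callers can catch deliberately.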
Much better - I tried to find a way to make the autoformatter format this nicely, but couldn't figure it out. Yours is great and it reads even better.
@sgugger Thank you very much for your suggestion, I definitely agree with this. Do you think I should create a new PR to replace this part in all the models where desired?
Since it was a nit and not a requirement, I'd say whatever works the best for you, @stancld - thank you!
@stas00 Thanks for the quick reply. For now I would leave it as it is, and I will replace these assertions in a new PR once `head_mask` and `decoder_head_mask` are implemented for all encoder-decoder models. That way this change isn't mixed into this PR, even though it's only a minor one; if that's okay :)
That works well, @stancld!
Thank you for your great contribution!
This PR implements `head_mask` and `decoder_head_mask` for PyTorch LED (and Longformer, as there's a copy dependency) and is a follow-up to the open issue #9814.

Motivation: This PR is part of an endeavour to enable the usage of `head_mask` and `decoder_head_mask` for all encoder-decoder transformers, following the recent work on BART-like models (#9569).

Fixes: #9814

Reviewers: @patrickvonplaten @LysandreJik @stas00
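Conceptually, a per-layer head mask is multiplied head-wise into the attention probabilities, so a 0 entry silences that head. A hypothetical stdlib-only sketch of that idea (function name and plain-list representation are illustrative, not the transformers implementation):

```python
def apply_head_mask(attn_probs, layer_head_mask):
    """Multiply a per-head mask into attention probabilities.

    attn_probs: list of per-head probability rows, length == num_heads.
    layer_head_mask: sequence of num_heads scalars (e.g. 1.0 keeps a head,
    0.0 prunes it), or None to leave the probabilities untouched.
    """
    if layer_head_mask is None:
        return attn_probs
    if len(layer_head_mask) != len(attn_probs):
        raise ValueError(
            f"Head mask for a single layer should be of size {(len(attn_probs),)}, "
            f"but is {(len(layer_head_mask),)}"
        )
    return [[m * p for p in head] for m, head in zip(layer_head_mask, attn_probs)]
```

In the real models the same multiplication happens on tensors inside each attention layer, with the per-layer slice of `head_mask` (or `decoder_head_mask`) broadcast over the attention score dimensions.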