Add separated decoder_head_mask for T5 Models #9634
patrickvonplaten merged 8 commits into huggingface:master
Conversation
* Add decoder_head_mask args into T5Model and T5ForConditionalGeneration
* Slightly change the order of input args to be in accordance with the convention from BART-based models introduced within PR huggingface#9569
* Separate head_mask and decoder_head_mask args in TF T5 models
* Update test_forward_signature in tests/test_modeling_tf_common.py w.r.t. the changed order of input args
Great, that looks nice! Let's first merge #9569 and then rebase this PR so that it passes all tests :-)
LysandreJik
left a comment
Cool, welcome feature!
This is a slight breaking change, isn't it? Users who previously passed head_mask had it control both the encoder and the decoder. From now on, specifying head_mask only controls the encoder, leaving the decoder to be controlled by decoder_head_mask.
Can we do a deprecation cycle, where if no decoder_head_mask is given, we set it to the value of head_mask? Having a FutureWarning there would be nice, too.
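The suggested deprecation cycle can be sketched as a small helper (the helper name is hypothetical; the models themselves would inline this logic in `forward`):

```python
import warnings

def resolve_decoder_head_mask(head_mask, decoder_head_mask):
    """Sketch of the suggested deprecation cycle: if only `head_mask` is
    given, reuse it for the decoder and emit a FutureWarning."""
    if head_mask is not None and decoder_head_mask is None:
        warnings.warn(
            "`head_mask` was split into `head_mask` and `decoder_head_mask`; "
            "`decoder_head_mask` currently defaults to `head_mask`.",
            FutureWarning,
        )
        decoder_head_mask = head_mask
    return decoder_head_mask
```

With this fallback, old code that passes only `head_mask` keeps its previous behaviour, while new code can pass both arguments independently.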
Thanks for fixing this! I have one note/question: This seems to only apply to self-attention heads, not heads in the cross-attention module, right? Is this intentional?
* Add FutureWarnings for T5 and TFT5 models warning a user that the input argument `head_mask` was split into two arguments - `head_mask` and `decoder_head_mask`
* Add default behaviour - `decoder_head_mask` is set to copy `head_mask`
@talkhaldi Thank you very much for pointing this out. It seems you're right and this was not intentional on my part. It'll be fixed in another commit.
* Make proper usage of head_mask and decoder_head_mask in cross_attention
* Fix conditions for raising FutureWarning
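The cross-attention fix above can be sketched as follows (a simplified, hypothetical decoder-layer helper; the real models pass per-layer mask slices into each block):

```python
def decoder_layer_masks(layer_idx, decoder_head_mask):
    """Return the head-mask slices a decoder layer would consume.

    Hypothetical sketch: after this fix, the decoder's cross-attention
    is masked with `decoder_head_mask` as well, instead of receiving no
    mask at all.
    """
    layer_mask = None if decoder_head_mask is None else decoder_head_mask[layer_idx]
    self_attn_mask = layer_mask
    cross_attn_mask = layer_mask  # previously the cross-attention went unmasked
    return self_attn_mask, cross_attn_mask
```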
Hey @patrickvonplaten and @LysandreJik. I've added some
LysandreJik
left a comment
LGTM! Thanks for making the change @stancld.
```python
if head_mask is not None and decoder_head_mask is None:
    if self.config.num_layers == self.config.num_decoder_layers:
        warning_msg = """
        The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`.
        Currently, `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be
        removed in future versions. If you do not want to use any `decoder_head_mask` now, please set
        `decoder_head_mask = torch.ones(num_layers, num_heads)`.
        """
        warnings.warn(warning_msg, FutureWarning)
        decoder_head_mask = head_mask
```
sgugger
left a comment
Thanks for your PR, it's very clean!
```python
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
    if self.config.num_layers == self.config.num_decoder_layers:
        warning_msg = """
```
I see this message is used twice; maybe it could be refactored into a private constant?
That's a good point, thanks!
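The refactor could look like the sketch below; the constant name and the helper function are illustrative, not necessarily what was merged:

```python
import warnings

# Module-level constant shared by every place that emits the warning,
# so the message text is not duplicated across model classes
# (the constant name here is illustrative).
_HEAD_MASK_WARNING_MSG = (
    "The input argument `head_mask` was split into two arguments `head_mask` "
    "and `decoder_head_mask`. Currently, `decoder_head_mask` is set to copy "
    "`head_mask`, but this feature is deprecated and will be removed in "
    "future versions."
)

def warn_head_mask_split():
    """Emit the shared FutureWarning from one place."""
    warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
```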
| """ | ||
| # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask | ||
| if head_mask is not None and decoder_head_mask is None: | ||
| warning_msg = """ |
Fix issue #9632

This PR separates `head_mask` and `decoder_head_mask` for T5 models, and thus makes it possible to specify different head masks for the encoder and the decoder.

Description:
* Replace the single `head_mask` with a separated pair `head_mask` and `decoder_head_mask` for the T5 models: `T5Model`, `T5ForConditionalGeneration`, `TFT5Model`, `TFT5ForConditionalGeneration`
* The order of input arguments follows the convention: `"input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "head_mask", "decoder_head_mask", "encoder_outputs"`
* This change currently breaks `test_forward_signature` in `tests/test_modeling_common.py`. This problem will be diminished once PR Add head_mask/decoder_head_mask for BART #9569 is merged.

Reviewer: @patrickvonplaten (the code is ready for review)
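The new argument order can be checked with a small `inspect`-based sketch, similar in spirit to what `test_forward_signature` does (the `forward` stub below is hypothetical, standing in for the real model's method):

```python
import inspect

def forward(input_ids, attention_mask=None, decoder_input_ids=None,
            decoder_attention_mask=None, head_mask=None,
            decoder_head_mask=None, encoder_outputs=None):
    """Stub mirroring the expected leading argument order for T5 forward."""
    return None

# The first seven parameter names should match the documented convention.
expected = ["input_ids", "attention_mask", "decoder_input_ids",
            "decoder_attention_mask", "head_mask", "decoder_head_mask",
            "encoder_outputs"]
assert list(inspect.signature(forward).parameters)[: len(expected)] == expected
```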