Add separated decoder_head_mask for T5 Models #9634

Merged
patrickvonplaten merged 8 commits into huggingface:master from stancld:decoder_mask_for_T5
Jan 19, 2021

Conversation

@stancld
Contributor

@stancld stancld commented Jan 16, 2021

Fix issue #9632


This PR separates head_mask and decoder_head_mask for T5 models, enabling different head masks to be specified for the encoder and the decoder.

Description:

  • Replace the single input argument head_mask with a separate pair of arguments, head_mask and decoder_head_mask, for the T5 models: T5Model, T5ForConditionalGeneration, TFT5Model, TFT5ForConditionalGeneration
  • Slightly change the order of input arguments to follow the convention for the first 7 arguments introduced in PR Add head_mask/decoder_head_mask for BART #9569 for BART-based models, i.e. "input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "head_mask", "decoder_head_mask", "encoder_outputs"
  • Currently, the updated PyTorch T5 model does not pass test_forward_signature in tests/test_modeling_common.py. This problem will be resolved once PR Add head_mask/decoder_head_mask for BART #9569 is merged.
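As a rough sketch of the argument-order convention described above, here is a hypothetical stand-in function (the names match the PR description, but the body is illustrative only and is not the actual transformers signature):

```python
import inspect

# Hypothetical stand-in for the T5 forward signature after this PR;
# only the names and the order of the first 7 arguments are taken
# from the PR description.
def t5_forward(
    input_ids=None,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    head_mask=None,          # masks encoder self-attention heads
    decoder_head_mask=None,  # masks decoder attention heads
    encoder_outputs=None,
):
    # Illustrative body: just hand back the two head masks.
    return head_mask, decoder_head_mask

# The ordering can be verified the same way test_forward_signature does:
expected = [
    "input_ids", "attention_mask", "decoder_input_ids",
    "decoder_attention_mask", "head_mask", "decoder_head_mask",
    "encoder_outputs",
]
assert list(inspect.signature(t5_forward).parameters)[:7] == expected
```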

Reviewer: @patrickvonplaten (the code is ready for review)

* Add decoder_head_mask args into T5Model and T5ForConditionalGeneration

* Slightly change the order of input args to be in accordance
with the convention from BART-based models introduced within the PR huggingface#9569.
* Separate head_mask and decoder_head_mask args in TF T5 models

* Slightly change the order of input args to follow convention
of BART-based models updated in PR huggingface#9569

* Update test_forward_signature tests/test_modeling_tf_common.py
w.r.t. the changed order of input args
@patrickvonplaten
Contributor

Great, that looks nice! Let's first merge #9569 and then rebase this PR so that it passes all tests :-)

Member

@LysandreJik LysandreJik left a comment


Cool, a welcome feature!

This is a slight breaking change, isn't it? Users who previously used head_mask had it control both their encoder and their decoder. From now on, head_mask only controls the encoder, while the decoder is controlled by decoder_head_mask.

Can we do a deprecation cycle where, if no decoder_head_mask is given, we set it to the value of head_mask? A FutureWarning there would be nice, too.
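The suggested deprecation cycle can be sketched as a small standalone helper (the function name is hypothetical; the real change lives inside the T5 forward methods):

```python
import warnings

def resolve_head_masks(head_mask, decoder_head_mask):
    """Backward-compatible fallback, as suggested in the review:
    if only `head_mask` is given, reuse it for the decoder and emit
    a FutureWarning. Sketch only, not the actual library code."""
    if head_mask is not None and decoder_head_mask is None:
        warnings.warn(
            "`head_mask` was split into `head_mask` and `decoder_head_mask`; "
            "`decoder_head_mask` currently defaults to `head_mask`, but this "
            "behaviour is deprecated and will be removed in future versions.",
            FutureWarning,
        )
        decoder_head_mask = head_mask
    return head_mask, decoder_head_mask
```

With this fallback, old code that passes only head_mask keeps its previous behaviour (both stacks masked) while being nudged toward the new argument.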

@talkhaldi
Contributor

Thanks for fixing this!

I have one note/question: this seems to apply only to self-attention heads, not to heads in the cross-attention module, right? Is this intentional?
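To illustrate the distinction raised here: each decoder layer contains both a self-attention and a cross-attention module, and both have heads that a per-layer decoder_head_mask can zero out. A minimal pure-Python sketch (lists stand in for tensors; the helper name is an assumption, not the transformers implementation):

```python
def apply_head_mask(attn_weights, layer_head_mask):
    """Zero out the attention weights of masked heads.

    attn_weights: one weight matrix (list of rows) per head.
    layer_head_mask: one 1.0/0.0 value per head for this layer.
    """
    return [
        [[w * m for w in row] for row in head_weights]
        for head_weights, m in zip(attn_weights, layer_head_mask)
    ]

# In a decoder block, the same per-layer mask should cover BOTH modules:
#   self_attn  = apply_head_mask(self_attn_weights,  decoder_head_mask[layer])
#   cross_attn = apply_head_mask(cross_attn_weights, decoder_head_mask[layer])
```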

* Add FutureWarnings for T5 and TFT5 models warning a user that
input argument `head_mask` was split into two arguments -
`head_mask` and `decoder_head_mask`

* Add default behaviour - `decoder_head_mask` is set to copy
`head_mask`
@stancld
Contributor Author

stancld commented Jan 19, 2021

@talkhaldi Thank you very much for pointing this out. It seems you're right, and this was not intentional on my part. It'll be fixed in another commit.

* Make proper usage of head_mask and decoder_head_mask
in cross_attention

* Fix conditions for raising FutureWarning
@stancld
Contributor Author

stancld commented Jan 19, 2021

Hey @patrickvonplaten and @LysandreJik. I've added FutureWarnings to the code to handle cases where only head_mask is passed by the user. I also fixed the cross-attention issue noted by @talkhaldi.
I believe the PR is now ready for review, as all the checks have passed after the rebase.

Member

@LysandreJik LysandreJik left a comment


LGTM! Thanks for making the change @stancld.

Comment on lines +1280 to +1289
if head_mask is not None and decoder_head_mask is None:
    if self.config.num_layers == self.config.num_decoder_layers:
        warning_msg = """
        The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`.
        Currently, `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be
        removed in future versions. If you do not want to use any `decoder_head_mask` now, please set
        `decoder_head_mask = torch.ones(num_layers, num_heads)`.
        """
        warnings.warn(warning_msg, FutureWarning)
        decoder_head_mask = head_mask
Member


Great message!

Collaborator

@sgugger sgugger left a comment


Thanks for your PR, it's very clean!

# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
    if self.config.num_layers == self.config.num_decoder_layers:
        warning_msg = """
Collaborator


I see this message is used twice; maybe it could be refactored into a private constant?

Contributor Author


That's a good point, thanks!

    """
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
    warning_msg = """
Collaborator


Same here.
