T5 migration to new masking interface #41804

Merged
vasqu merged 13 commits into huggingface:main from Aravind-11:t5-migration-to-new-masking-interface
Nov 11, 2025

@Aravind-11
Contributor

Aravind-11 commented Oct 23, 2025

What does this PR do?

This PR migrates the T5 model to use the new masking utilities (masking_utils.py) for attention mask creation.
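For context, T5 is an encoder-decoder model, so two kinds of mask are involved: the encoder attends bidirectionally, while decoder self-attention is causal. The sketch below illustrates the difference with plain Python lists. It is not the masking_utils API; the helper names `bidirectional_mask` and `causal_mask` are hypothetical and exist only for this illustration.

```python
# Illustrative only: plain-Python boolean masks, NOT the masking_utils API.
# True means "this query position may attend to this key position".

def bidirectional_mask(padding):
    """Encoder-style mask: every query sees every non-padded key."""
    n = len(padding)
    return [[bool(padding[k]) for k in range(n)] for _ in range(n)]

def causal_mask(padding):
    """Decoder-style mask: a query sees non-padded keys at or before its own position."""
    n = len(padding)
    return [[bool(padding[k]) and k <= q for k in range(n)] for q in range(n)]

# One sequence of length 4 whose last token is padding (1 = real token, 0 = pad).
padding = [1, 1, 1, 0]
enc = bidirectional_mask(padding)  # rows are queries, columns are keys
dec = causal_mask(padding)
```

In both masks the padded key column is never attended to; the causal mask additionally hides keys that come after the query position.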

Fixes #40743

Before submitting

Who can review?

@vasqu @Rocketknight1

All existing test cases pass except test_small_integration_test, which needs a GPU. I need guidance on which additional test cases to add for this. Thank you.

Contributor

vasqu left a comment

create_bidirectional_mask is not correctly used atm + let's remove unnecessary comments (including the old ones that already exist, e.g. on the cross attn mask)

I expect that some other models depend on T5, can you check with make fix-copies?

@Aravind-11
Contributor Author

> create_bidirectional_mask is not correctly used atm + let's remove unnecessary comments (including the old ones that already exist, e.g. on the cross attn mask)
>
> I expect that some other models depend on T5, can you check with make fix-copies?

Thank you for your review, @vasqu. I changed the parameters of create_bidirectional_mask. I ran make fix-copies, and one dependent model, mT5, was updated automatically.

The corresponding functions in modeling_mt5.py now reflect the new create_causal_mask / create_bidirectional_mask usage.

@vasqu
Contributor

vasqu commented Oct 24, 2025

run-slow: mt5, t5

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/mt5', 'models/t5']
quantizations: [] ...

@vasqu
Contributor

vasqu commented Oct 24, 2025

run-slow: mt5, t5

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/mt5', 'models/t5']
quantizations: [] ...

@vasqu
Contributor

vasqu commented Oct 24, 2025

@Aravind-11 this is a bit more complicated than anticipated, you can see the failing tests here https://github.com/huggingface/transformers/actions/runs/18776158789/job/53571258876

This means that torch export is failing. Also we would need to remove traces of the _update_causal... function

@vasqu
Contributor

vasqu commented Oct 24, 2025

run-slow: mt5, t5

@vasqu
Contributor

vasqu commented Oct 24, 2025

Not sure why executorch is failing with the new API tbh, reverted these parts. If you find anything that fixes this, then go ahead! I would wait a bit for you, but if the current state is ok with you, I would merge this.

cc @Cyrilvallez if you have time to check why exchanging the causal mask fails here for executorch

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/mt5', 'models/t5']
quantizations: [] ...

@Aravind-11
Contributor Author

> Not sure why executorch is failing with the new API tbh, reverted these parts. If you find anything that fixes this, then go ahead! I would wait a bit for you, but if the current state is ok with you, I would merge this.
>
> cc @Cyrilvallez if you have time to check why exchanging the causal mask fails here for executorch

Thanks @vasqu! I agree, it makes sense to merge this as-is for now. The main T5 migration seems to work well and I think Executorch can be debugged later. Appreciate the review and all the follow-up commits 🙏.

@vasqu
Contributor

vasqu commented Oct 27, 2025

Checking some stuff with masking in #41852. Maybe I can fix the issue here as well then

@Aravind-11
Contributor Author

> Checking some stuff with masking in #41852. Maybe I can fix the issue here as well then

Should the changes made in that PR be reflected here too?

@vasqu
Contributor

vasqu commented Oct 27, 2025

> Not sure why executorch is failing with the new API tbh, reverted these parts.

I meant these, working on a masking version that requires less special treatment for executorch

@Aravind-11
Contributor Author

> > Not sure why executorch is failing with the new API tbh, reverted these parts.
>
> I meant these, working on a masking version that requires less special treatment for executorch

Got it.

@Aravind-11
Contributor Author

> > > Not sure why executorch is failing with the new API tbh, reverted these parts.
> >
> > I meant these, working on a masking version that requires less special treatment for executorch
>
> Got it.

Do you want me to migrate those changes here as well?

@vasqu
Contributor

vasqu commented Oct 27, 2025

No not yet at least

@vasqu
Contributor

vasqu commented Oct 29, 2025

I checked that #41852 works with t5 and all new masks API. Let's wait for that PR and then change it here to use the new mask API for causal masks as well

@Aravind-11
Contributor Author

> I checked that #41852 works with t5 and all new masks API. Let's wait for that PR and then change it here to use the new mask API for causal masks as well

Got it. Thank you.

@Aravind-11
Contributor Author

> I checked that #41852 works with t5 and all new masks API. Let's wait for that PR and then change it here to use the new mask API for causal masks as well

I think issue #28818 can also be solved once #41852 is approved?

    causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
else:
    causal_mask = None
    causal_mask = create_bidirectional_mask(
Member

Big mismatch between variable name and what it contains here 😉 either we're causal, or bidirectional

Contributor

Fair enough, should be more general (this issue was inherited from the previous code tho)
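
As background on the line `(1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min` in the hunk above: it turns a 0/1 mask into an additive one that is summed onto the attention scores before the softmax. The arithmetic is the same whether the mask is causal or bidirectional, which is why a more general variable name fits better. A minimal plain-Python sketch, assuming `-sys.float_info.max` as a stand-in for `torch.finfo(dtype).min`; the helper `to_additive` is hypothetical:

```python
import sys

# Stand-in (assumption) for torch.finfo(dtype).min: the most negative finite float.
MOST_NEGATIVE = -sys.float_info.max

def to_additive(mask_row):
    """Map 1 (attend) to zero and 0 (masked) to a very large negative number."""
    return [(1.0 - m) * MOST_NEGATIVE for m in mask_row]

# Two real tokens followed by one masked position.
additive = to_additive([1, 1, 0])
```

Attended positions come out as negative zero, which compares equal to `0.0` and leaves the score unchanged; masked positions receive a value so negative that the softmax assigns them effectively zero weight.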

@vasqu
Contributor

vasqu commented Nov 10, 2025

run-slow: mt5, t5

@github-actions
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/mt5", "models/t5"]
quantizations: []

@github-actions
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: mt5, t5

Contributor

vasqu left a comment

I pushed the new changes for causal, thx a lot for your patience @Aravind-11

@Cyrilvallez I'd merge this tomorrow or so unless you have found anything critical

@Aravind-11
Contributor Author

> I pushed the new changes for causal, thx a lot for your patience @Aravind-11
>
> @Cyrilvallez I'd merge this tomorrow or so unless you have found anything critical

Awesome!!

vasqu enabled auto-merge (squash) November 11, 2025 18:01
vasqu merged commit 2b8068c into huggingface:main Nov 11, 2025
23 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu
Contributor

vasqu commented Nov 11, 2025

Thx for sticking through this PR! Super neat to have this

@Aravind-11
Contributor Author

> Thx for sticking through this PR! Super neat to have this

Thank you for the code review!!

SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026
* Refactor: migrate T5 attention masking to masking_utils interface

* Refactor: migrate T5 attention masking to masking_utils interface

* create_bidirectional_mask function with appropriate paramaters

* create_bidirectional_mask function with appropriate paramaters

* fixup executorch + import

* revert causal masks

* rm executorch stuff

* add causal mask with non vmap

* copies

* remove unnecessary import

---------

Co-authored-by: Vasqu <antonprogamer@gmail.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>