add FlashAttentionKwargs and seq_idx to flat collator #36456
Cyrilvallez merged 31 commits into huggingface:main
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the "Ready for review" button.
I'd like to take a look as well when you think you're ready, so gladly ping me then :)
@vasqu could you please take a look when you have time? Thank you.
vasqu left a comment
Smaller issues/nits overall. I'd be for a warning in case settings might cause issues (RoPE with the FA kwargs only).
Otherwise, not on you but I think it would be nice to have equivalent collator tests on torch at least. Usually, all corresponding paths are tested (pt, tf, np).
    return_position_ids=True,
    return_flash_attn_kwargs=False,
IIUC, #35941 points to a problem where the FA kwargs will cause issues on the RoPE paths downstream (under FA True, position ids False).
I'd be for a warning on init in case of the bad combo of FA kwargs True and position ids False (maybe someone has a different use case, which shouldn't directly cause errors).
maybe someone has a different use case which shouldn't directly cause errors
Yeah, I thought about warnings in cases like that, but I was hesitant because of different requirements for different models.
Like, if a transformer model uses FA but not RoPE, then FA True, pos_ids False makes sense. And for a mamba-only model (like mamba2), FA False, pos_ids False, seq_idx True is what you'd use.
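For illustration, those two configurations would look roughly like this (a sketch only; the flag names follow this PR's collator, and the exact defaults may differ):

```python
from transformers import DataCollatorWithFlattening

# Transformer that uses FA but not RoPE: FA kwargs on, position ids off.
fa_only_collator = DataCollatorWithFlattening(
    return_position_ids=False,
    return_flash_attn_kwargs=True,
)

# Mamba-only model (e.g. mamba2): no FA kwargs, no position ids, but seq_idx.
mamba_collator = DataCollatorWithFlattening(
    return_position_ids=False,
    return_flash_attn_kwargs=False,
    return_seq_idx=True,
)
```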
So IMO it should be up to the model to ensure it's getting the right inputs it needs and to raise a ValueError or similar if an improper combination is passed.
Fair point, then the FA utils should be written in a way that ensures this (which should be a follow-up PR to this one).
I don't think this can be addressed at the level of the FA utils, since different models can logically use different valid combinations here. The FA utils just need to be able to handle the different combos.
It does look like non-trivial FlashAttentionKwargs with position_ids=attention_mask=None is currently not supported, though. IIUC, you'd end up in this block with all of your cu_seq_lens_{q,k} etc. ignored.
Couldn't we just detect whether FA kwargs were passed (before the if/else into the different paths) and handle it then if position_ids is None? It might be an error (unintentional path) or a warning (no-padding path); unsure here. (Or we could even implement that path ourselves, which seems unwanted.)
IMO it's a bit confusing that the original flash attn can handle those args while we can't. As a user, it would silently fall through when I'm familiar with FA but not transformers.
Couldn't we just detect if fa kwargs were passed (before the if else into the different paths) and handle it then if position_ids is None?
Yeah, not sure why this isn't done in the current FA utils. Also found this confusing.
Yeah, either way I think some sort of handling is warranted here which should help ease up the bamba checks (hopefully).
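To make the proposal concrete, here is a toy sketch of the dispatch order being discussed (purely illustrative, not the actual transformers FA utils; the path names are made up):

```python
# Toy sketch of the proposed ordering: check explicitly passed
# FlashAttentionKwargs before the attention_mask / position_ids branches,
# so they aren't silently ignored. Purely illustrative.
def choose_fa_path(attention_mask=None, position_ids=None,
                   cu_seq_lens_q=None, cu_seq_lens_k=None, **kwargs):
    if cu_seq_lens_q is not None and cu_seq_lens_k is not None:
        return "varlen_from_fa_kwargs"      # use the provided cu_seq_lens directly
    if attention_mask is not None:
        return "unpad_from_attention_mask"  # unpad inputs using the 2D mask
    if position_ids is not None:
        return "varlen_from_position_ids"   # derive cu_seq_lens from position_ids
    return "padded"                         # plain padded FA call


assert choose_fa_path(cu_seq_lens_q=[0, 3, 7], cu_seq_lens_k=[0, 3, 7]) == "varlen_from_fa_kwargs"
assert choose_fa_path() == "padded"
```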
Ah, the only issue I see here is that the order of args is breaking!
Ah, good catch @ArthurZucker , I'll move the separator_id back to the second position.
I found it a little surprising that it's all …
Would you be willing to add the other paths, or at least pt? Seems like an oversight on the initial PR 👀
Yep, and I just discovered that the … So, this is all going to take some reworking. I'll ping again later.
Alright, made a few changes:
Do you have any advice on the test @vasqu? I wanted a non-trivial test that the collator outputs are in the right format (which is easy to get wrong), but the above seems like overkill. EDIT: I only verified that the …
vasqu left a comment
I think it's overall solid. Some implementations could be simplified IMO (e.g. the wrapper at the end of the collator), and I left some other smaller comments.
No batch dimension is expected on any of the FlashAttentionKwargs, and all these variables and seq_idx must be int32 rather than the default int64. I removed the reliance on the default data collators to achieve this.
Personally, I don't think that the comment about the batch dim provides any real value - would drop it, it kinda confused me at first. The int32 adjustments are good; only worried about max_seq_len_q/k (should be a simple py int, but see comments).
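For concreteness, a hypothetical collated batch for two examples of lengths 3 and 4 could look like this under the conventions discussed (keys follow this PR; the values themselves are made up):

```python
import torch

batch = {
    "input_ids": torch.tensor([[5, 6, 7, 8, 9, 10, 11]]),         # (1, total_len)
    "labels": torch.tensor([[-100, 6, 7, -100, 9, 10, 11]]),      # separator_id at each example start
    "position_ids": torch.tensor([[0, 1, 2, 0, 1, 2, 3]]),        # restart per example
    "seq_idx": torch.tensor([[0, 0, 0, 1, 1, 1, 1]], dtype=torch.int32),
    "cu_seq_lens_q": torch.tensor([0, 3, 7], dtype=torch.int32),  # no batch dim
    "cu_seq_lens_k": torch.tensor([0, 3, 7], dtype=torch.int32),
    "max_length_q": 4,                                            # plain Python int
    "max_length_k": 4,
}
```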
I added a ModelTesterMixin::test_flash_attention_2_padding_matches_padding_free_with_position_ids_from_flat_collator which both has an awfully long name and is probably an unnecessarily expensive addition to the test suite.
IMO, it could be added to the previous padding-free test. If you want to keep it as is, I'd suggest a renaming ("from flattening collator" is rather non-telling).
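Conceptually, the padding-free check amounts to something like the sketch below (not the actual ModelTesterMixin code; it assumes a causal-LM model and a standard attention_mask in the padded batch):

```python
import torch

def assert_padding_free_matches_padded(model, padded_batch, flat_batch, atol=1e-3):
    # Run the same examples once padded and once flattened by the collator,
    # then compare the per-token outputs at the non-padded positions.
    with torch.no_grad():
        padded_logits = model(**padded_batch).logits  # (bsz, seq_len, vocab)
        flat_logits = model(**flat_batch).logits      # (1, total_len, vocab)
    keep = padded_batch["attention_mask"].bool()
    torch.testing.assert_close(padded_logits[keep], flat_logits[0], atol=atol, rtol=0)
```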
Added tf and pt tests for the flat collator.
🥳
vasqu left a comment
LGTM! cc @ArthurZucker for core maintainer review
Looks like you'll need to run …
@ArthurZucker please let me know if I can answer any further questions.

@vasqu any advice here?

@garrett361 sorry, I don't have anything to add. Arthur is quite busy, so notifications can easily get lost at times; I'm gonna cc again.

cc @ArthurZucker @Cyrilvallez for core maintainer review
ArthurZucker left a comment
Thanks for your PR, sorry that it got lost 😓 merging 🤗
Cyrilvallez left a comment
Just a few nits on top of Arthur's review as I was looking at it at the same time!
No worries, I know you're super busy. Will make these changes and ping.
The failing test is due to infra issues. I believe I made all the requested changes. CC @ArthurZucker @Cyrilvallez
Alright @vasqu

Hi @Cyrilvallez and @ArthurZucker, a reminder about this one, please.
Cyrilvallez left a comment
Merging! Thanks a lot for the great PR! 🤗 Super sorry we took this long!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Thanks @Cyrilvallez!
Commits:
* add flash attn kwargs to flattening collator
* add return_seq_idx option
* doc string edits
* cleaner max len updates
* various fixes
* temp testing code
* return int32 seq_idx and FlashAttnKwargs
* DataCollatorIntegrationTest impl
* fix batch dims and dtypes
* fill out remaining collator tests
* test name change and fmt
* rm unused var
* fmt
* minor change
* fmt
* add missing pos_ids check
* consistent {np,pt,tf} tests
* split pt tests into 3, like np/tf tests
* mv comment, rename fa test
* remove batch dim comment
* simply wrapping
* compute cu_seq_len/max_length once
* fmt
* remove tf code
* rm warning
* move separator_id back to 2nd pos
* use cleaner lists in tests
* ret -> batch
* fmt
* attr ordering
* use py ints for max_length_{k,q}
What does this PR do?
Adds additional, optional return values in DataCollatorWithFlattening, as needed for padding-free training with particular models. Relates to #35861 and #35941.
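A minimal usage sketch with the new flags (flag names per this PR; the toy features are made up):

```python
from transformers import DataCollatorWithFlattening

collator = DataCollatorWithFlattening(
    return_flash_attn_kwargs=True,  # cu_seq_lens_{q,k}, max_length_{q,k}
    return_seq_idx=True,            # per-token sequence index
)

features = [
    {"input_ids": [1, 2, 3]},
    {"input_ids": [4, 5, 6, 7]},
]
batch = collator(features)
# Expected keys include input_ids, labels, position_ids,
# cu_seq_lens_q/k, max_length_q/k, and seq_idx.
print(sorted(batch.keys()))
```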
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.