Conversation
            None,
            dropout=0.0 if not self.training else self.attention_dropout,
-           scaling=None,
+           scaling=self.scaling,
super weird that this still works, can you run the test / check where this is from (blame)? might be expected 😐
I think it works with sdpa and other kernels because they default the scaling to 1/sqrt(head_dim) when None is passed, checking
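(For illustration, a minimal standalone sketch — not the transformers code, shapes are made up — of why the non-eager paths survive a None scaling: torch's SDPA falls back to 1/sqrt(head_dim) when `scale=None`.)

```python
import torch
import torch.nn.functional as F

# Toy shapes (batch, heads, seq_len, head_dim) -- purely illustrative.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# scale=None lets SDPA fall back to 1/sqrt(head_dim) internally,
# which matches passing the scaling explicitly.
out_default = F.scaled_dot_product_attention(q, k, v, scale=None)
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=64 ** -0.5)
print(torch.allclose(out_default, out_explicit, atol=1e-6))  # True
```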
changing tests, they were broken too 😀
I ran the tests, it seems to work. I think it was broken in the interim after the recent VLMs refactor.
...but, because of TP, setting the scaling manually here could break the vision attention module. I've added a vision_eager_attention_forward to at least fix the eager forward that is currently broken, while preserving the behaviour we had and erring on the side of caution. @ArthurZucker alright with you?
Oh man, I see that we use the config's head_dim, and it's only set for the text attention 👀
So in essence, the local head_dim was previously used in the vision attention. But depending on the TP degree, this is different. Did I understand it correctly? Imo, this is still a mistake and not worth keeping for BC: we essentially have different models based on different TP degrees - that seems unreasonable to me and not the point of TP.
Hmm, looking into the code
(transformers/src/transformers/models/llama4/modeling_llama4.py, lines 1065 to 1074 at 0f77ca7)
It seems to me the global head_dim was previously used (via module.head_dim) before #37576 - not sure what the local head_dim would be.
I think None was passed (self.scaling was unset before the PR you linked). Definitely agree lol, the whole point of tp_plan is to have one model for all possible TP degrees. Passing scaling=None should be safe in all cases here; what was broken was mainly `eager_attention_forward`, which did not compute a scaling but only used the passed one, as you say... so now on main, if you use eager, you get a float * NoneType multiplication error here.
So sdpa/flex/etc are fine IMO, just the eager path is currently broken. One simple way that should be DTensor-compatible when it comes to scaling could be e.g. `scaling = 1.0 / torch.sqrt(torch.tensor(query.size(-1), dtype=query.dtype, device=query.device))`. WDYT?
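(A minimal sketch of what such a None-tolerant eager path could look like — the function name and signature here are illustrative assumptions, not the actual transformers implementation:)

```python
import torch

def eager_attention_forward_sketch(query, key, value, scaling=None, dropout=0.0, training=False):
    # Fall back to 1/sqrt(head_dim) derived from the query itself when no
    # scaling is passed; query.size(-1) is the per-head dim, so this stays
    # valid regardless of how the heads are sharded by TP. The tensor-based
    # variant suggested above would slot in here the same way.
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scaling
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=training)
    return torch.matmul(attn_weights, value)
```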
Shouldn't we be able to just pass self.scaling here @
But yea, if that doesn't work, then I'm definitely pro your solution :D
yeah that was my initial fix here 9ffde0b! but I wasn't sure why things were in this state, it's good to put it in writing 🤗
ArthurZucker left a comment
LGTM, thanks for digging!
@vasqu happy if you upstream the changes to our interface to better support scaling / defaults etc., to make it more user friendly and avoid these kinds of issues for us!
Sounds good :D
* fix wrong scaling value/default Cache init
* style
* fix various issues on integration tests
* change expected outputs
* fixup
* fix config access
* protect default scaling
What does this PR do?
Fixes a couple of nits/wrong defaults in the Llama4 code for DynamicCache init - not sure we should even default to it, I think.
The scaling, however, has to be passed, else `eager_attention_forward` will fail.
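(For context, a tiny standalone reproduction of the failure mode — not the actual model code — showing what the eager path hits when the scaling is left as None:)

```python
import torch

q = torch.randn(2, 4, 8)
k = torch.randn(2, 4, 8)
scaling = None  # what effectively reached the eager path before this fix

# Raises TypeError: unsupported operand type(s) for *: 'Tensor' and 'NoneType'
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scaling
```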