
Minor llama4 fixes #38123

Merged
molbap merged 9 commits into main from minor_fixes
May 20, 2025

Conversation

@molbap
Contributor

@molbap molbap commented May 14, 2025

What does this PR do?

Fixes a couple of nits/wrong defaults in the Llama4 code for the DynamicCache init - I'm not sure we should even default to it.
The scaling, however, has to be passed, else `eager_attention_forward` will fail.
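
For illustration, a minimal sketch of the gist of the fix, based on the diff discussed in the review thread below; the surrounding variable names are assumptions, not the exact merged code:

# hypothetical call site inside the attention module's forward
attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    None,  # attention_mask
    dropout=0.0 if not self.training else self.attention_dropout,
    scaling=self.scaling,  # was scaling=None, which breaks the eager path
)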

@github-actions github-actions Bot marked this pull request as draft May 14, 2025 08:50
@github-actions
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@molbap molbap marked this pull request as ready for review May 14, 2025 08:50
@molbap molbap requested a review from ArthurZucker May 14, 2025 08:50
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread on src/transformers/models/llama4/modeling_llama4.py (Outdated)

     None,
     dropout=0.0 if not self.training else self.attention_dropout,
-    scaling=None,
+    scaling=self.scaling,
Collaborator

super weird that this still works, can you run the test / check where this is from (blame)? might be expected 😐

Contributor Author

I think it works with sdpa and other kernels because they default the scaling to 1/sqrt(head_dim) when None is passed, checking
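
As a quick standalone illustration of that default (not code from this PR): with scale=None, PyTorch's SDPA falls back to 1/sqrt(head_dim) internally, so the two calls below should match (assuming a PyTorch version that exposes the scale keyword):

import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# scale=None lets SDPA pick 1/sqrt(head_dim) on its own
out_default = F.scaled_dot_product_attention(q, k, v, scale=None)
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(q.size(-1)))
torch.testing.assert_close(out_default, out_explicit)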

Contributor Author

changing tests, they were broken too 😀

Contributor Author

I ran the tests and it seems to work. I think it was broken in the interim after the recent VLMs refactor.

Contributor Author

...but, because of TP, setting the scaling manually here could break the vision attention module. I've added a vision_eager_attention_forward to at least fix the eager forward that is currently broken, while preserving the behaviour we had, to err on the side of caution. @ArthurZucker alright with you?

Contributor

@vasqu vasqu May 16, 2025

Oh man, I see that we use the config's head dim and it's only set and done so in the text attention 👀

So in essence, the local head_dim was previously used in the vision attention, but depending on the TP degree this is different. Did I understand that correctly? Imo, this is still a mistake and not worth preserving for BC: we essentially have different models based on different TP degrees - that seems unreasonable to me and not the point of TP.

Contributor

@vasqu vasqu May 16, 2025

Hmm, looking into the code

class Llama4VisionAttention(nn.Module):
    def __init__(self, config: Llama4VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = 1
        self.attention_dropout = config.attention_dropout
        self.scaling = self.head_dim**-0.5

It seems to me, the global head_dim was previously used (by module.head_dim) before #37576 - not sure what the local head_dim would be.

Contributor Author

I think None was passed (self.scaling was unset before the PR you linked). Definitely agree lol, the whole point of tp_plan is to have one model for all possible TP degrees. Passing scaling=None should be safe in all cases here; what was broken was mainly eager_attention_forward, which did not compute the scaling itself but only used the passed-in one, as you say... so now on main, if you use eager, you get a float * NoneType multiplication error here.
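
Roughly what that failure looks like, as a simplified standalone reproduction (the actual multiplication happens inside the shared eager attention helper):

import torch

query = torch.randn(1, 8, 16, 64)
key = torch.randn(1, 8, 16, 64)
scaling = None  # what the vision attention effectively forwards on main

# the eager path multiplies the raw scores by the passed-in scaling instead of
# recomputing it, so a None value fails right here:
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
# TypeError: unsupported operand type(s) for *: 'Tensor' and 'NoneType'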

So sdpa/flex/etc are fine IMO; just the eager path is currently broken. One simple way that should be DTensor-compatible when it comes to scaling could be e.g. `scaling = 1.0 / torch.sqrt(torch.tensor(query.size(-1), dtype=query.dtype, device=query.device))`. WDYT?
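
A sketch of what a protected vision eager forward could look like under that suggestion; the name and signature are assumptions for illustration, not necessarily the merged implementation:

import torch
import torch.nn.functional as F

def vision_eager_attention_forward(module, query, key, value, attention_mask,
                                   scaling=None, dropout=0.0, **kwargs):
    if scaling is None:
        # derive the default from the query itself (the DTensor-friendly fallback suggested above)
        scaling = 1.0 / torch.sqrt(
            torch.tensor(query.size(-1), dtype=query.dtype, device=query.device)
        )
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights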

Contributor

Shouldn't we be able to just pass self.scaling here?

But yea, if that doesn't work, then I'm definitely pro your solution :D

Contributor Author

yeah, that was my initial fix here, 9ffde0b! but I wasn't sure why things were in this state, it's good to put it in writing 🤗

@molbap molbap enabled auto-merge (squash) May 15, 2025 16:39
@molbap molbap disabled auto-merge May 15, 2025 16:43
@molbap molbap requested a review from ArthurZucker May 15, 2025 19:21
Collaborator

@ArthurZucker ArthurZucker left a comment

LGTM thanks for digging!
@vasqu happy if you upstream the changes to our interface to better support scaling / defaults etc., to make it more user-friendly and spare us these kinds of issues!

@vasqu
Contributor

vasqu commented May 20, 2025

Sounds good :D

@molbap molbap enabled auto-merge (squash) May 20, 2025 13:03
@molbap molbap merged commit 9cde2f5 into main May 20, 2025
21 checks passed
@molbap molbap deleted the minor_fixes branch May 20, 2025 13:15
faaany pushed a commit to faaany/transformers that referenced this pull request May 21, 2025
* fix wrong scaling value/default Cache init

* style

* fix various issues on integration tests

* change expected outputs

* fixup

* fix config access

* protect default scaling
xvyv99 pushed a commit to xvyv99/transformers that referenced this pull request May 21, 2025
* fix wrong scaling value/default Cache init

* style

* fix various issues on integration tests

* change expected outputs

* fixup

* fix config access

* protect default scaling