Skip to content

Fix: bamba error handling kwargs with forward pass #35378

Closed
Ssukriti wants to merge 2 commits intohuggingface:mainfrom
Ssukriti:fix_bamba_error_forward_Trainer
Closed

Fix: bamba error handling kwargs with forward pass #35378
Ssukriti wants to merge 2 commits intohuggingface:mainfrom
Ssukriti:fix_bamba_error_forward_Trainer

Conversation

@Ssukriti
Copy link
Copy Markdown
Contributor

This PR adds a small change to handle additional kwargs passed to the forward function of BambaModel architecture.

Fixes a behavior when tuning Bamba models - when HF Trainer would pass additional LossArgs to BambaForCausalLM.forward() , which are passed to BambaModel.forward , and need to be ignored.
This would fix the error - num_items_in_batch is an unexpected arg , while tuning the model.

Additional Context:
LlamaForCausalLM class passes all kwargs to self.model.forward , hence they need to be handled.
In the future, we would add support for flashAttentionKwargs in BambaModel.forward() .

cc: @fabianlim

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
@Ssukriti Ssukriti force-pushed the fix_bamba_error_forward_Trainer branch from abe82db to ffe384e Compare December 20, 2024 23:34
@fabianlim
Copy link
Copy Markdown
Contributor

@ArthurZucker @molbap this is just a temporary minor fix to make the model work with trainer.Trainer. We have not properly handled the FlashAttentionKwargs yet because we plan to get to that when we support padding free. But that work is involved, and for now we want to make sure the model is able to tune properly

cc: @ani300 @raghukiran1224

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can close this as #35875 added this! 🤗 sorry for the delay!

@Ssukriti Ssukriti deleted the fix_bamba_error_forward_Trainer branch April 10, 2025 23:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants