
Add _loss_is_scaled_for_ga to allow custom trainers to control gradient accumulation loss scaling #43651

Open
abigailtech wants to merge 2 commits into huggingface:main from abigailtech:fix-loss-scaling-condition

Conversation

@abigailtech

Added a _loss_is_scaled_for_ga property that custom trainers can override to explicitly control gradient accumulation loss scaling. The default implementation preserves backward compatibility, so existing trainers are unaffected; a custom trainer can now simply override the property to return False instead of manipulating model_accepts_loss_kwargs.

Fixes #43604
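For illustration, a minimal sketch of how a custom trainer might use the proposed property (the subclass and override shown here are an assumption based on the description above, not code taken from this PR):

```python
from transformers import Trainer

class MyTrainer(Trainer):
    @property
    def _loss_is_scaled_for_ga(self) -> bool:
        # Signal that the loss computed by this trainer is NOT already scaled
        # for gradient accumulation, so the base Trainer should apply its own
        # scaling instead of inferring the behaviour from model_accepts_loss_kwargs.
        return False
```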

@Rocketknight1
Member

cc @qgallouedec

@qgallouedec
Member

If I understand correctly, model_accepts_loss_kwargs checks whether the model forward has kwargs, but it is actually used to decide whether the loss should be scaled or not, am I right? The underlying assumption is that if the model accepts kwargs, then it takes num_items_in_batch as an argument, which means that it scales the loss by itself.
TBH, I'm wondering if we should instead aim for a more ambitious refactor here.
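For example, the assumed contract looks roughly like this (a toy sketch of the idea, not the actual transformers implementation):

```python
import torch.nn.functional as F

def forward(logits, labels, **loss_kwargs):
    # If the Trainer passes num_items_in_batch, the model scales the loss
    # itself; otherwise it falls back to a plain mean over the labels.
    num_items_in_batch = loss_kwargs.get("num_items_in_batch")
    loss = F.cross_entropy(logits, labels, reduction="sum")
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    else:
        loss = loss / labels.numel()
    return loss
```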

@abigailtech
Author

If I understand correctly, model_accepts_loss_kwargs checks whether the model forward has kwargs, but it is actually used to decide whether the loss should be scaled or not, am I right? The underlying assumption is that if the model accepts kwargs, then it takes num_items_in_batch as an argument, which means that it scales the loss by itself. TBH, I'm wondering if we should instead aim for a more ambitious refactor here.

Yes, that's right. I'd be open to a more ambitious refactor; do you have a specific direction in mind?



Development

Successfully merging this pull request may close these issues.

Revisit the condition for scaling the loss

3 participants