[trainer] fix the GA model_accepts_loss_kwargs #34915

Merged: ArthurZucker merged 4 commits into main from fix-ga-fix, Dec 5, 2024
Conversation

@ArthurZucker (Collaborator) commented Nov 25, 2024

What does this PR do?

Fixes #34577
model_accepts_loss_kwargs was wrongly looking at kwarg names, when it should only check that the forward signature accepts **kwargs at all (since the name can vary: FlashAttentionKwargs, LossKwargs, etc.).
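
For illustration, the kind of check described — detecting a **kwargs catch-all in the forward signature instead of matching a specific parameter name — could be sketched like this (the helper name is hypothetical, not the Trainer's actual code):

```python
import inspect

def accepts_loss_kwargs(model) -> bool:
    # Hypothetical helper: True if model.forward takes a **kwargs catch-all.
    # Checking for VAR_KEYWORD covers any typed-kwargs name
    # (FlashAttentionKwargs, LossKwargs, ...), unlike matching a name.
    params = inspect.signature(model.forward).parameters.values()
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params)
```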

@muellerzr (Contributor) left a comment

TIL! However, as you can see from the failing test, this doesn't always work 😅 (if we can get it to, that's great; I think that's originally why I went with explicit rather than implicit)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ArthurZucker merged commit a928d9c into main on Dec 5, 2024
ArthurZucker deleted the fix-ga-fix branch on Dec 5, 2024 at 15:37
@techkang (Contributor) commented Dec 6, 2024

I think this PR introduced a new bug: if the user uses a user-defined loss function and model_accepts_loss_kwargs is False, the compute_loss function cannot get the num_items_in_batch argument, so the user-defined compute_loss_func will not receive it either.
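
A simplified sketch of the flow being reported (illustrative, not the actual Trainer source; the names mirror the conversation):

```python
def compute_loss(model, inputs, labels, *, model_accepts_loss_kwargs,
                 compute_loss_func=None, num_items_in_batch=None):
    loss_kwargs = {}
    if model_accepts_loss_kwargs:
        loss_kwargs["num_items_in_batch"] = num_items_in_batch
    outputs = model(**inputs, **loss_kwargs)
    if compute_loss_func is not None:
        # If the kwarg was never added above, a user-defined loss
        # function has no way to receive the batch item count.
        return compute_loss_func(outputs, labels, **loss_kwargs)
    return outputs.loss
```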

@ArthurZucker (Collaborator, Author)

model_accepts_loss_kwargs is supposed to be determined at init time and only depends on the forward pass of the model

@techkang (Contributor) commented Dec 6, 2024

I understand now; I misinterpreted the if condition earlier. However, with this PR, when model_accepts_loss_kwargs is True, num_items_in_batch is not passed, which makes the GA loss fix ineffective. Why is that?

@techkang (Contributor) commented Dec 6, 2024

With the newest code, running

export RUN_SLOW=True
pytest tests/trainer/test_trainer.py::TrainerIntegrationPrerunTest::test_gradient_accumulation_loss_alignment

fails with the following error:

======================================================= short test summary info =======================================================
FAILED tests/trainer/test_trainer.py::TrainerIntegrationPrerunTest::test_gradient_accumulation_loss_alignment - AssertionError: 0.9038000000000004 not less than 0.1 : Difference -0.9038000000000004 is not within 0.1
============================================== 1 failed, 2 warnings in 102.43s (0:01:42) ==============================================
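
For context on what this test checks: with gradient accumulation, each micro-batch loss should be summed and then normalized by the item count of the whole accumulated batch; taking a mean per micro-batch diverges whenever micro-batches contain unequal numbers of items. A minimal sketch of that normalization, assuming a plain token-level cross-entropy (illustrative, not the Trainer source):

```python
import torch.nn.functional as F

def scaled_micro_batch_loss(logits, labels, num_items_in_batch):
    # Sum (not mean) over this micro-batch, then divide by the item count
    # of the full accumulated batch; summing these over all micro-batches
    # reproduces the single large-batch mean loss.
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), reduction="sum"
    )
    return loss / num_items_in_batch
```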

@ArthurZucker (Collaborator, Author)

Ah shit, the if condition is reversed
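
Illustratively (a sketch, not the exact diff), the inverted guard withheld num_items_in_batch from exactly the models that could accept it:

```python
def loss_kwargs_for(accepts_loss_kwargs: bool, num_items_in_batch: int) -> dict:
    # Buggy version (inverted guard): `if not accepts_loss_kwargs: ...`
    # withheld the count precisely from models that could take it.
    # Intended guard:
    if accepts_loss_kwargs:
        return {"num_items_in_batch": num_items_in_batch}
    return {}
```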

@ArthurZucker (Collaborator, Author)

Opened a PR for a fix, thanks!


Development

Successfully merging this pull request may close this issue: Mismatched keyword argument names of llama make GA fix invalid

4 participants