
Fix GA loss bugs and add unit test#35121

Merged
ArthurZucker merged 9 commits into huggingface:main from techkang:main
Dec 9, 2024

Conversation

Contributor

@techkang techkang commented Dec 6, 2024

What does this PR do?

There are two ways to fix GA loss bugs:

  1. Use num_items_in_batch in the loss function defined by the model. In this case, model_accepts_loss_kwargs is True.
  2. The model doesn't have a loss function, or the user has a self-defined loss function (compute_loss_func).

However, the previous unit test only covered the second condition, so I introduced a new unit test to cover the first condition and fixed the bugs along the way.
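To illustrate condition 1, here is a minimal, self-contained sketch (compute_loss is a hypothetical helper, not the exact Trainer internals): summing the per-token loss and normalizing by num_items_in_batch makes the sum over accumulated micro-batches match the loss of a single large batch.

```python
import torch

def compute_loss(logits, labels, num_items_in_batch=None):
    # Sum the per-item loss, then normalize by the total number of items
    # across all gradient-accumulation micro-batches (condition 1).
    loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), reduction="sum"
    )
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss

# Two accumulated micro-batches: normalizing each by num_items_in_batch
# makes their summed loss equal the loss of one large batch.
torch.manual_seed(0)
logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
full = compute_loss(logits, labels, num_items_in_batch=4)
accum = sum(
    compute_loss(logits[i : i + 2], labels[i : i + 2], num_items_in_batch=4)
    for i in range(0, 4, 2)
)
assert torch.allclose(full, accum)
```

Without the num_items_in_batch normalization, each micro-batch would be averaged over its own size, which is where the GA loss discrepancy comes from.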

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @ArthurZucker

Contributor

@muellerzr muellerzr left a comment

Thanks! These solutions make sense, and I ran the tests myself. Left some nits to allow less leeway on the test closeness. cc @ArthurZucker

Comment thread: tests/trainer/test_trainer.py (Outdated)
Contributor

Let's be a bit more aggressive and do 0.15 (this passes). I still feel that's quite big but I can't figure out why (my tests showed 0.001 should be doable).

Suggested change
self.assertLess(max(diff_truth), 0.3, f"Difference {max(diff_truth)} is not within 0.3")
self.assertLess(max(diff_truth), 0.15, f"Difference {max(diff_truth)} is not within 0.15")

Contributor Author

It is strange: I tested on both Mac and Windows, and max(diff_truth) is 0.144, so 0.15 may fail on some other machines.

Contributor Author

Done! I used TinyStories to narrow the gap down to the same value as yours. The code is submitted.

Comment thread: tests/trainer/test_trainer.py (Outdated)
Contributor

Suggested change
self.assertLess(max(diff_truth), 0.3, f"Difference {max(diff_truth)} is not within 0.3")
self.assertLess(max(diff_truth), 0.2, f"Difference {max(diff_truth)} is not within 0.2")

Similarly, we can be aggressive here too.

Contributor Author

I managed to reduce the gap to 1e-4 by padding all input labels to the same length. However, this method did not work for the GPT-2 model. I will continue to explore other solutions.
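The padding idea can be sketched as follows (pad_labels is a hypothetical helper, not code from this PR): every label sequence is padded to the same length with -100, which cross-entropy ignores, so the per-batch token counts stay consistent across micro-batches.

```python
import torch

def pad_labels(batch_labels, max_len, pad_id=-100):
    # Pad every label sequence to max_len with -100; positions holding
    # -100 are skipped by cross_entropy (its default ignore_index), so
    # padding does not change the loss or the valid-token count.
    padded = torch.full((len(batch_labels), max_len), pad_id, dtype=torch.long)
    for i, labels in enumerate(batch_labels):
        padded[i, : len(labels)] = torch.tensor(labels)
    return padded

out = pad_labels([[1, 2, 3], [4, 5]], max_len=4)
```

With equal-length label tensors, each micro-batch contributes the same shape to the loss, which is what narrows the GA-vs-full-batch gap.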

Collaborator

@ArthurZucker ArthurZucker left a comment

Sorry @muellerzr but this does not solve:

FAILED examples/pytorch/test_pytorch_examples.py::ExamplesTests::test_run_speech_recognition_seq2seq - TypeError: Wav2Vec2Model.forward() got an unexpected keyword argument 'num_items_in_batch'

so I am not sure I understand. Related to #35113 and #35128.
We can't merge with a broken test.

Contributor Author

techkang commented Dec 7, 2024

@ArthurZucker The Wav2Vec2Model bug occurs because SpeechEncoderDecoderModel takes variable arguments as forward parameters:


But it dispatches the parameters to its encoder and decoder, which don't accept variable arguments:

kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
kwargs_decoder = {
    argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}

I think the better solution is to modify its decoder to accept variable arguments. I pushed a new commit and the test succeeds.
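The fix pattern can be sketched with toy classes (ToyDecoder and ToyEncoderDecoder are illustrative names, not the actual transformers classes): once the decoder's forward accepts **kwargs, pass-through arguments such as num_items_in_batch no longer raise TypeError.

```python
class ToyDecoder:
    # The fix: accept **kwargs so dispatched extras like
    # num_items_in_batch are tolerated instead of raising TypeError.
    def forward(self, input_ids, **kwargs):
        return {"ids": input_ids, "extras": kwargs}

class ToyEncoderDecoder:
    def __init__(self):
        self.decoder = ToyDecoder()

    def forward(self, input_ids, **kwargs):
        # Mirrors the kwargs_decoder dispatch quoted above: strip the
        # "decoder_" prefix and forward those kwargs to the decoder.
        kwargs_decoder = {
            k[len("decoder_"):]: v
            for k, v in kwargs.items()
            if k.startswith("decoder_")
        }
        return self.decoder.forward(input_ids, **kwargs_decoder)

out = ToyEncoderDecoder().forward([1, 2], decoder_num_items_in_batch=8)
```

Before the fix, a decoder with a fixed signature would reject the forwarded num_items_in_batch keyword, which is exactly the TypeError seen in test_run_speech_recognition_seq2seq.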

Contributor Author

techkang commented Dec 8, 2024

Finally, all unit tests passed. Please check again. @muellerzr @ArthurZucker

Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks a lot @techkang !
I did not dive deep enough into the test, my bad 🤗
Merging ASAP and doing the patch

import datasets

model_name = "distilgpt2"
model_name = "nickypro/tinyllama-110M"
Collaborator

Okay, would be nice if we had a safetensors model here but alright.

@ArthurZucker ArthurZucker merged commit 1ccca8f into huggingface:main Dec 9, 2024
ArthurZucker pushed a commit that referenced this pull request Dec 10, 2024
* fix GA bugs and add unit test

* narrow down model loss unit test diff gap

* format code to make ruff happy

* send num_items_in_batch argument to decoder

* fix GA loss bug in BertLMHeadModel

* use TinyStories-33M to narrow down diff gap

* format code

* missing .config

* avoid add extra args

---------

Co-authored-by: kangsheng <kangsheng@meituan.com>
