Update trainer for easier handling of accumulate, compile fixes, and proper reporting #34511
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Rocketknight1
left a comment
Couple comments about the test!
Rocketknight1
left a comment
Tests look clean to me now, and I'm trusting you on the accelerate side of things! 😅
cc @LysandreJik / @ArthurZucker for core maintainer review
    context = (
        functools.partial(self.accelerator.no_sync, model=model)
        if i == len(batch_samples) - 1
        else contextlib.nullcontext
    )
To explain what's going on here @Rocketknight1: during DDP we use model.no_sync() so gradients are only communicated across all GPUs on the step outside it (which speeds up training by skipping unneeded syncs during gradient accumulation). accelerator.no_sync() is the lower-level API behind accumulate(), which makes that op backend-independent (on a single GPU it just does nullcontext).
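A minimal sketch of that pattern (illustrative only — the accumulation_step helper and loop shape are my assumptions, and it syncs only on the final micro-batch, per the explanation above):

```python
import contextlib
import functools

def accumulation_step(accelerator, model, optimizer, batch_samples):
    # Skip the gradient all-reduce on every micro-batch except the last, so
    # DDP communicates once per optimizer step; accelerator.no_sync() is
    # backend-independent and reduces to a null context on a single GPU.
    for i, batch in enumerate(batch_samples):
        context = (
            contextlib.nullcontext
            if i == len(batch_samples) - 1  # sync on the final micro-batch
            else functools.partial(accelerator.no_sync, model=model)
        )
        with context():
            loss = model(**batch).loss / len(batch_samples)
            accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```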
@Milad335t just warning you to stop spamming or we'll have to block you 😢
ArthurZucker
left a comment
Thanks, let's hope this gets stabilized!
    - num_items_in_batch = sum(
    -     [data_batch["labels"][..., 1:].ne(-100).sum().item() for data_batch in batch_samples]
    - )
    + num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
weird to me that we have to use -100 here instead of a general parameter, but this was already the case
IIRC we use -100 for padding by default in the Trainer. I can align it to self.processor if it exists, else -100, if that's better? :)
Actually our padding index is -100 everywhere.
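For context (a standard PyTorch fact, not from this thread): -100 is also torch.nn.CrossEntropyLoss's default ignore_index, which is why the Trainer can assume it for padded label positions.

```python
import torch

# CrossEntropyLoss skips label positions equal to ignore_index (default -100),
# so -100-padded labels contribute nothing to the loss.
loss_fct = torch.nn.CrossEntropyLoss()
print(loss_fct.ignore_index)  # -100
```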
okay, sounds good then, sorry
No worries, it's weird for me too :)
Why do we no longer need to shift the labels (["labels"][..., 1:]) when computing num_items_in_batch?
Sure, I also think we need to shift the labels before computing num_items_in_batch. Otherwise the value is incorrect, since the first element in labels may not be -100.
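A hedged sketch of the counting being discussed (count_items is a made-up helper, not Trainer code): under the causal shift, position 0 of the labels is never predicted, so it should be excluded from the count along with padding.

```python
import torch

def count_items(batch_samples, ignore_index=-100):
    # Count only supervised label tokens: drop the first position (never
    # predicted under the causal shift) and skip the padding/ignore index.
    return sum(
        batch["labels"][..., 1:].ne(ignore_index).sum().item()
        for batch in batch_samples
    )

# Toy check: 5 labels -> first position dropped, one -100 ignored -> 3 tokens
batch = {"labels": torch.tensor([[12, 7, 9, -100, 4]])}
assert count_items([batch]) == 3
```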
ArthurZucker
left a comment
Thanks, patching today!
…proper reporting (#34511)

* Update trainer for easier handling of accumulate + proper reporting
* test
* Fixup tests
* Full fix
* Fix style
* rm comment
* Fix tests
* Minimize test + remove py 311 check
* Unused import
* Forward contrib credits from discussions
* Fix reported metrics
* Refactor, good as it's going to get
* rm pad tok id check
* object detection and audio are being annoying
* Fin
* Fin x2

---------

Co-authored-by: Gyanateet Dutta <Ryukijano@users.noreply.github.com>
    else:
        if num_items_in_batch is not None:
            if self.compute_loss_func or self.model_accepts_loss_kwargs:
                loss *= self.args.gradient_accumulation_steps
I'm confused that the loss is no longer multiplied by the gradient accumulation steps here, because the loss has been multiplied by the data parallel size in https://github.com/huggingface/transformers/pull/34511/files#diff-ed55888e6665791fe92cc8fc0c499da54f4ace6738551cd9a2591881cda076deR3702-R3703
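To make the scaling question concrete, a toy example (my own numbers, not from the PR): if each micro-batch loss is normalized by the total token count and something downstream also divides by gradient_accumulation_steps, the result is too small by exactly that factor, so multiplying it back restores the intended value.

```python
G = 4                                     # gradient_accumulation_steps
token_loss_sums = [8.0, 6.0, 10.0, 4.0]   # summed token losses per micro-batch
num_items_in_batch = 70                   # total non-padding label tokens

# Intended loss for the accumulated batch: a single division by token count.
target = sum(token_loss_sums) / num_items_in_batch  # 0.4

# With an extra division by G downstream, the total comes out G times too
# small; multiplying by G up front cancels it.
scaled = sum((s / num_items_in_batch) / G for s in token_loss_sums)  # 0.1
assert abs(scaled * G - target) < 1e-12
```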
What does this PR do?
Alternative to #34442
TL;DR: we just need to remove lru_cache and everything will work fine (this PR also adds a test). It also takes the full lessons from my article and adds them to the Trainer for a simpler solution to the grad accum calculation (we shouldn't rely on accelerator from now on, because it can't handle the nuances of the grad accum fix at the highest-level API, so we use a lower-level version instead).

Fixes #34402
Would recommend a patch after this
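For readers following along, a hedged sketch of the normalization idea the description refers to (my reading of the approach, not the PR's literal code; causal_lm_loss is an illustrative helper): sum the token losses within each micro-batch and divide once by the total token count, instead of letting every micro-batch take its own mean.

```python
import torch.nn.functional as F

def causal_lm_loss(logits, labels, num_items_in_batch=None, ignore_index=-100):
    # Shift so each position predicts the next token.
    logits = logits[..., :-1, :].contiguous()
    labels = labels[..., 1:].contiguous()
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=ignore_index,
        # Sum per micro-batch when a global token count is available, so the
        # caller can divide once across the whole accumulated batch.
        reduction="sum" if num_items_in_batch is not None else "mean",
    )
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss
```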
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @Rocketknight1