[OSS] Balance the trainable params only #262
Conversation
…ve little consequences

Thank you very much for fixing that, @blefaudeux
```python
_params_t = Any

class BucketFlush(Enum):
```

dead code removal
```python
if work_handle.callback is not None:
    work_handle.callback()

def _handle_trailing_buckets(self, flush_type: BucketFlush) -> None:
```

same, sorry about that
```python
param_lists[rank].append(param)
sizes[rank] += param.numel()

# We're partitioning the optimizer state,
```

This is the real change: the partitioning did not take into account whether the params are trainable, although that is what counts for the optimizer state. The HuggingFace test case was pathological in that respect: there was one big non-trainable parameter (which went to rank 0) and then a lot of cumulatively smaller trainable parameters, which all went to rank 1. This meant that the model was effectively optimized on rank 1 only, defeating the whole purpose of sharding.
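A minimal sketch of the idea described above, assuming a greedy least-loaded assignment: parameters are still handed out to ranks, but only *trainable* parameters count toward each rank's load, since only those carry optimizer state. `FakeParam` and `partition_parameters` are illustrative stand-ins, not fairscale's actual implementation.

```python
from dataclasses import dataclass

@dataclass
class FakeParam:
    """Illustrative stand-in for a torch.nn.Parameter."""
    n: int
    requires_grad: bool = True

    def numel(self) -> int:
        return self.n

def partition_parameters(params, world_size):
    param_lists = [[] for _ in range(world_size)]
    sizes = [0] * world_size  # trainable element count per rank
    for param in params:
        # Assign to the currently least-loaded rank
        rank = sizes.index(min(sizes))
        param_lists[rank].append(param)
        # Frozen params don't add to the rank's optimizer-state load
        if param.requires_grad:
            sizes[rank] += param.numel()
    return param_lists
```

With one large frozen parameter and many small trainable ones, the trainable load now ends up balanced across ranks instead of piling onto a single rank.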
Pinging for reviews, if you don't mind, I would love master to work for HuggingFace
```python
params.append(torch.rand(size, 1))

# Make sure that the params are trainable, enforces size-based partitioning
for p in params:
```

Do we need to add a test case where some params are not trainable, too?
```python
o = optim.OSS(params, lr=0.1)
assert len(o.param_groups) == 1
o.add_param_group({"params": [torch.rand(3, 1)]})
```

@min-xu-ai this parameter is non-trainable actually, so there's a mix of both
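As a side note on the mix above: this is standard PyTorch behavior, where a plain `torch.rand` tensor is non-trainable by default, so the freshly added param group contains a frozen parameter unless `requires_grad` is enabled explicitly. A quick sketch:

```python
import torch

# Plain tensors are non-trainable by default
t = torch.rand(3, 1)
assert t.requires_grad is False

# Opting in makes the tensor trainable (and thus optimizer-state-bearing)
p = torch.rand(3, 1).requires_grad_()
assert p.requires_grad is True
```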
The CI issue seems unrelated: the pipe benchmark runs on a host with an old CUDA. I'm missing some context but could have a look later. cc @msbaines just in case
min-xu-ai left a comment

Really nice. Thanks for adding comments and tests.
@blefaudeux, when you feel all the important changes have been merged into master, please ping me so that I can re-test with transformers. Then it'd be great to make a new fairscale release on PyPI, and we can announce it as working with transformers. I have a brief doc ready to merge once the above is settled: huggingface/transformers#9208 (in case you'd like to add anything there, please don't hesitate to suggest). Then we can announce that transformers has the Sharded ZeRO features from fairscale integrated. Yay!
Hi @stas00, should be good to go! The current CircleCI issue is unrelated: the default torch install via pip is incompatible with the CUDA provided on those machines. The two PRs required for HuggingFace (which addressed genuine issues) have landed. Please keep me posted if you encounter any issues, and thanks for the great work around that! Looking forward to the announcement, I'll have a look :) (spotty availability right now, but doing my best)
That's great! I rebuilt and retested and everything looks good. Would it be possible to make a new release on PyPI first, and then we are good to announce? Thank you very much!
Before submitting
What does this PR do?
Fixes #261 and huggingface/transformers#9156. I suspect that this was also partially to blame for the iGPT training runs.
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Yes 🙃 HuggingFace / ShardedDDP / AMP works, at least on the dummy example shared.