[OSS] Balance the trainable params only #262
Changes from all commits: 232a54b, e305f10, ba09a9a, c8e2169, ca194ec, a5c83a2
```diff
@@ -5,7 +5,6 @@
 from collections import OrderedDict, deque
 import copy
-from enum import Enum, auto
 import itertools
 from itertools import chain
 import logging
```
|
```diff
@@ -27,11 +26,6 @@
 _params_t = Any
 
 
-class BucketFlush(Enum):
-    Reduce = auto()
-    Broadcast = auto()
-
-
 class OSS(Optimizer):
     """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
     optimizer and shards its state as described by ZeRO_.
```
|
|
```diff
@@ -139,7 +133,16 @@ def partition_parameters(self) -> List[List[dict]]:
                     # Add this param to rank with smallest size.
                     rank = sizes.index(min(sizes))
                     param_lists[rank].append(param)
-                    sizes[rank] += param.numel()
+
+                    # We're partitioning the optimizer state,
+                    # so trainable parameters are the ones which really count
+                    if param.requires_grad:
+                        sizes[rank] += param.numel()
+                    else:
+                        # Spread frozen params on a per-tensor basis.
+                        # Mostly useful to balance the partitions when fine tuning,
+                        # for instance; not strictly required.
+                        sizes[rank] += 1
 
             for rank, params in enumerate(param_lists):
                 param_group_rank = copy.copy(param_group)
```

> **Author (comment):** This is the real change: the partitioning did not take into account whether the params are trainable or not, although that is what counts for the optimizer state. The Huggingface test case was somewhat pathological in that respect: there was one big non-trainable parameter (which went to rank 0) and then a lot of cumulatively smaller trainable parameters, which all went to rank 1. This meant that the model was effectively optimized on rank 1 alone, defeating the whole purpose of sharding.
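The balancing behavior described in the comment above can be illustrated with a toy sketch outside of fairscale (the `partition` helper and the parameter sizes below are made up for illustration; params are modeled as `(numel, requires_grad)` pairs rather than real tensors):

```python
def partition(params, world_size):
    """Greedily assign each (numel, requires_grad) pair to the rank
    with the smallest accumulated size, mirroring the diff above."""
    param_lists = [[] for _ in range(world_size)]
    sizes = [0] * world_size
    for numel, requires_grad in params:
        # Add this param to the rank with the smallest accumulated size.
        rank = sizes.index(min(sizes))
        param_lists[rank].append((numel, requires_grad))
        # Only trainable params carry optimizer state, so only they drive
        # the balancing; frozen params count as 1 so they still spread out.
        sizes[rank] += numel if requires_grad else 1
    return param_lists, sizes

# The pathological shape from the PR discussion: one big frozen param
# followed by many cumulatively smaller trainable ones.
params = [(10_000, False)] + [(100, True)] * 8
lists, sizes = partition(params, world_size=2)
print(sizes)  # -> [401, 400]: trainable work is split across both ranks
```

With the old accounting (`sizes[rank] += param.numel()` unconditionally), the 10,000-element frozen param would pin rank 0's size and every trainable param would land on rank 1, so all the optimizer work would happen on a single rank.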
|
|
```diff
@@ -585,30 +588,6 @@ def _try_consume_work_handle(self) -> None:
         if work_handle.callback is not None:
             work_handle.callback()
 
-    def _handle_trailing_buckets(self, flush_type: BucketFlush) -> None:
-        """
-        Go through the buckets, flush them if not already empty
-        .. warning: Could be that a bucket flush was already requested, needs to be handled carefully
-        """
-
-        for bucket_list in self.buckets.values():
-            for bucket in bucket_list:
-                if bucket.current_offset > 0:
-                    self.work_handles.append(
-                        Workhandle(
-                            handle=dist.broadcast(
-                                tensor=bucket.buffer, src=bucket.global_ref_rank, group=self.group, async_op=True,
-                            )
-                            if flush_type == BucketFlush.Broadcast
-                            else dist.reduce(
-                                tensor=bucket.buffer, dst=bucket.global_ref_rank, group=self.group, async_op=True,
-                            ),
-                            callback=bucket.unroll,
-                        )
-                    )
-
-        self._consume_work_handles()
-
     def _setup_bucket_strategy(self) -> None:
         """ Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
         (smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
```

> **Author (comment on `_handle_trailing_buckets`):** Same, sorry about that.
> **Author (comment):** Dead code removal.