diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py index 18dce8c7a..26dfe7df2 100644 --- a/fairscale/optim/oss.py +++ b/fairscale/optim/oss.py @@ -5,7 +5,6 @@ from collections import OrderedDict, deque import copy -from enum import Enum, auto import itertools from itertools import chain import logging @@ -27,11 +26,6 @@ _params_t = Any -class BucketFlush(Enum): - Reduce = auto() - Broadcast = auto() - - class OSS(Optimizer): """Wraps an arbitrary :class:`optim.Optimizer ` optimizer and shards its state as described by ZeRO_. @@ -139,7 +133,16 @@ def partition_parameters(self) -> List[List[dict]]: # Add this param to rank with smallest size. rank = sizes.index(min(sizes)) param_lists[rank].append(param) - sizes[rank] += param.numel() + + # We're partitioning the optimizer state, + # so trainable parameters are the ones which really count + if param.requires_grad: + sizes[rank] += param.numel() + else: + # Spread frozen params on a per-tensor basis + # Mostly useful to balance partitions when fine-tuning, for instance + # Not strictly required + sizes[rank] += 1 for rank, params in enumerate(param_lists): param_group_rank = copy.copy(param_group) @@ -585,30 +588,6 @@ def _try_consume_work_handle(self) -> None: if work_handle.callback is not None: work_handle.callback() - def _handle_trailing_buckets(self, flush_type: BucketFlush) -> None: - """ - Go through the buckets, flush them if not already empty - .. 
warning: Could be that a bucket flush was already requested, needs to be handled carefully - """ - - for bucket_list in self.buckets.values(): - for bucket in bucket_list: - if bucket.current_offset > 0: - self.work_handles.append( - Workhandle( - handle=dist.broadcast( - tensor=bucket.buffer, src=bucket.global_ref_rank, group=self.group, async_op=True, - ) - if flush_type == BucketFlush.Broadcast - else dist.reduce( - tensor=bucket.buffer, dst=bucket.global_ref_rank, group=self.group, async_op=True, - ), - callback=bucket.unroll, - ) - ) - - self._consume_work_handles() - def _setup_bucket_strategy(self) -> None: """ Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered (smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent diff --git a/tests/optim/test_oss.py b/tests/optim/test_oss.py index e30a68427..13650fa59 100644 --- a/tests/optim/test_oss.py +++ b/tests/optim/test_oss.py @@ -178,16 +178,47 @@ def test_implicit_local_state_dict(self): def run_test_add_param_group(rank, world_size, tempfile_name): dist_init(rank, world_size, tempfile_name) - params = [] - for size in [4, 5, 2, 6, 4]: - params.append(torch.rand(size, 1)) - o = optim.OSS(params, lr=0.1) - assert len(o.param_groups) == 1 - o.add_param_group({"params": [torch.rand(3, 1)]}) - assert len(o.param_groups) == 2 - # Verify that added group is added to the correct partition making all have 8 elements. 
- assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8 - assert len(o.optim.param_groups) == 2 + + # Test with all parameters trainable to begin with + def all_trainable(): + params = [] + for size in [4, 5, 2, 6, 4]: + params.append(torch.rand(size, 1)) + + # Make sure that the params are trainable, enforces size-based partitioning + for p in params: + p.requires_grad = True + + o = optim.OSS(params, lr=0.1) + + assert len(o.param_groups) == 1 + o.add_param_group({"params": [torch.rand(3, 1)]}) + + assert len(o.param_groups) == 2 + # Verify that added group is added to the correct partition making all have 8 elements. + assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8 + assert len(o.optim.param_groups) == 2 + + # Test a pathological config with a first big non-trainable param + def some_trainable(): + params = [] + for size in [100, 3, 5, 2, 6, 4]: + params.append(torch.rand(size, 1)) + + # Keep the first (big) param frozen: only the others count for size-based partitioning + for p in params[1:]: + p.requires_grad = True + + o = optim.OSS(params, lr=0.1) + + assert len(o.param_groups) == 1 + o.add_param_group({"params": [torch.rand(3, 1)]}) + + assert len(o.param_groups) == 2 + assert len(o.optim.param_groups) == 2 + + all_trainable() + some_trainable() dist.destroy_process_group() @@ -303,6 +334,11 @@ def run_test_sharding(rank, world_size, tempfile_name): params = [] for size in [5, 4, 2, 6, 4, 3]: params.append(torch.rand(size, 1)) + + # Make sure that the params are trainable, enforces size-based partitioning + for p in params: + p.requires_grad = True + o = optim.OSS(params, lr=0.1) assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8