41 changes: 10 additions & 31 deletions fairscale/optim/oss.py
@@ -5,7 +5,6 @@

from collections import OrderedDict, deque
import copy
from enum import Enum, auto
import itertools
from itertools import chain
import logging
@@ -27,11 +26,6 @@
_params_t = Any


class BucketFlush(Enum):
> **Contributor Author:** dead code removal

    Reduce = auto()
    Broadcast = auto()


class OSS(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
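Since the rest of the class docstring is folded in this view, here is a minimal usage sketch. It is illustrative only: it assumes a `torch.distributed` process group has already been initialized, and the base optimizer class is passed in explicitly.

```python
import torch
from fairscale.optim import OSS

# Illustrative only: assumes torch.distributed.init_process_group() was called already.
model = torch.nn.Linear(10, 10)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

loss = model(torch.rand(4, 10)).sum()
loss.backward()
optimizer.step()  # each rank updates only its own shard, then broadcasts the result
```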
@@ -139,7 +133,16 @@ def partition_parameters(self) -> List[List[dict]]:
# Add this param to rank with smallest size.
rank = sizes.index(min(sizes))
param_lists[rank].append(param)
sizes[rank] += param.numel()

> **blefaudeux (Contributor Author), Dec 18, 2020:** This is the real change: the partitioning was not taking into account whether the params are trainable or not, although that is what counts for the optimizer state. The Hugging Face test case was pathological in that respect: one big non-trainable parameter (which goes to rank 0), then a lot of cumulatively smaller trainable parameters, which all went to rank 1. This meant that the model was effectively optimized on rank 1 only, defeating the whole purpose of sharding. (A standalone sketch of this effect follows the hunk below.)

# We're partitioning the optimizer state,
# so trainable parameters are the ones which really count
if param.requires_grad:
    sizes[rank] += param.numel()
else:
    # Spread frozen params on a per-tensor basis.
    # Mostly useful to balance partitions when fine-tuning, for instance;
    # not strictly required.
    sizes[rank] += 1

for rank, params in enumerate(param_lists):
    param_group_rank = copy.copy(param_group)
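To make the comment above concrete, here is a small standalone sketch (not part of the diff) of the greedy partitioning, with a flag to toggle between the old and new accounting. The `partition` helper and the toy sizes are made up for illustration:

```python
import torch

def partition(params, world_size, count_frozen_numel):
    """Greedy size-based partitioning, mirroring partition_parameters() above."""
    sizes = [0] * world_size
    ranks = []
    for p in params:
        rank = sizes.index(min(sizes))  # rank with the smallest load so far
        ranks.append(rank)
        if p.requires_grad or count_frozen_numel:
            sizes[rank] += p.numel()
        else:
            sizes[rank] += 1  # new behavior: frozen params barely count
    return ranks

# One big frozen param followed by many small trainable ones
params = [torch.rand(100, 1)] + [torch.rand(5, 1).requires_grad_() for _ in range(10)]

# Old accounting: the frozen param saturates rank 0, so every trainable
# param (and hence all optimizer state) piles up on rank 1.
print(partition(params, world_size=2, count_frozen_numel=True))   # [0, 1, 1, 1, ...]

# New accounting: the trainable params are spread evenly across both ranks.
print(partition(params, world_size=2, count_frozen_numel=False))  # [0, 1, 0, 1, ...]
```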
@@ -585,30 +588,6 @@ def _try_consume_work_handle(self) -> None:
if work_handle.callback is not None:
    work_handle.callback()

def _handle_trailing_buckets(self, flush_type: BucketFlush) -> None:
> **Contributor Author:** same, sorry about that

"""
Go through the buckets, flush them if not already empty
.. warning: Could be that a bucket flush was already requested, needs to be handled carefully
"""

for bucket_list in self.buckets.values():
for bucket in bucket_list:
if bucket.current_offset > 0:
self.work_handles.append(
Workhandle(
handle=dist.broadcast(
tensor=bucket.buffer, src=bucket.global_ref_rank, group=self.group, async_op=True,
)
if flush_type == BucketFlush.Broadcast
else dist.reduce(
tensor=bucket.buffer, dst=bucket.global_ref_rank, group=self.group, async_op=True,
),
callback=bucket.unroll,
)
)

self._consume_work_handles()

def _setup_bucket_strategy(self) -> None:
    """ Tag parameters to either be bucketed or broadcast/reduced directly. The parameters are ordered
    (smallest first); the bucket will hold the smallest elements, and the remaining ones will be sent directly
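As an aside, the smallest-first bucketing that `_setup_bucket_strategy`'s docstring describes can be sketched roughly as follows. This is an illustrative reconstruction, not fairscale's actual implementation; `plan_buckets` and `bucket_cap` are made-up names:

```python
import torch

def plan_buckets(params, bucket_cap: int):
    """Rough sketch: pack the smallest tensors into a fixed-size bucket,
    send everything that does not fit individually."""
    ordered = sorted(params, key=lambda p: p.numel())  # smallest first
    bucketed, direct = [], []
    offset = 0
    for p in ordered:
        if offset + p.numel() <= bucket_cap:
            bucketed.append(p)  # flattened into one buffer, one collective call
            offset += p.numel()
        else:
            direct.append(p)    # broadcast/reduced on its own
    return bucketed, direct

params = [torch.rand(n) for n in (2, 3, 50, 4, 300)]
small, big = plan_buckets(params, bucket_cap=16)
print([p.numel() for p in small], [p.numel() for p in big])  # [2, 3, 4] [50, 300]
```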
56 changes: 46 additions & 10 deletions tests/optim/test_oss.py
@@ -178,16 +178,47 @@ def test_implicit_local_state_dict(self):

def run_test_add_param_group(rank, world_size, tempfile_name):
    dist_init(rank, world_size, tempfile_name)
    params = []
    for size in [4, 5, 2, 6, 4]:
        params.append(torch.rand(size, 1))
    o = optim.OSS(params, lr=0.1)
    assert len(o.param_groups) == 1
    o.add_param_group({"params": [torch.rand(3, 1)]})
    assert len(o.param_groups) == 2
    # Verify that the added group went to the correct partition, so that all ranks hold 8 elements.
    assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
    assert len(o.optim.param_groups) == 2

    # Test with all parameters trainable to begin with
    def all_trainable():
        params = []
        for size in [4, 5, 2, 6, 4]:
            params.append(torch.rand(size, 1))

        # Make sure that the params are trainable, which enforces size-based partitioning
        for p in params:
            p.requires_grad = True

        o = optim.OSS(params, lr=0.1)

        assert len(o.param_groups) == 1
        o.add_param_group({"params": [torch.rand(3, 1)]})

        assert len(o.param_groups) == 2
        # Verify that the added group went to the correct partition, so that all ranks hold 8 elements.
        assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
        assert len(o.optim.param_groups) == 2

    # Test a pathological config with a first big non-trainable param
    def some_trainable():
        params = []
        for size in [100, 3, 5, 2, 6, 4]:
            params.append(torch.rand(size, 1))

        # Make all but the first (large) param trainable, to reproduce the pathological case described above
        for p in params[1:]:
            p.requires_grad = True

        o = optim.OSS(params, lr=0.1)

        assert len(o.param_groups) == 1
        o.add_param_group({"params": [torch.rand(3, 1)]})

        assert len(o.param_groups) == 2
        assert len(o.optim.param_groups) == 2

    all_trainable()
    some_trainable()

    dist.destroy_process_group()
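For reference, the expected per-rank total of 8 in `all_trainable` follows from the greedy placement, assuming a 3-process world size (which is what the assertion implies, since the parameters total 4 + 5 + 2 + 6 + 4 + 3 = 24 elements): the per-rank loads evolve as [4, 0, 0] → [4, 5, 0] → [4, 5, 2] → [4, 5, 8] → [8, 5, 8], and the added 3-element group then lands on the lightest rank, balancing all three at 8.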

@@ -303,6 +334,11 @@ def run_test_sharding(rank, world_size, tempfile_name):
    params = []
    for size in [5, 4, 2, 6, 4, 3]:
        params.append(torch.rand(size, 1))

    # Make sure that the params are trainable, which enforces size-based partitioning
    for p in params:
        p.requires_grad = True

    o = optim.OSS(params, lr=0.1)
    assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8
