Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
5f2baaa
WIP
blefaudeux Jan 13, 2021
9e98ac8
fix a loading issue, add a reproducibility unit test
blefaudeux Jan 13, 2021
b464e00
would need more checks, but getting there, saving a purely pytorch state
blefaudeux Jan 13, 2021
fc91d15
should be good to go, the OSS unit tests are still pretty horrible
blefaudeux Jan 13, 2021
70712ef
tentatively solving torch 1.5 compatibility
blefaudeux Jan 15, 2021
0b749d4
minimizing changes with this PR, removed some code reordering to make…
blefaudeux Jan 15, 2021
2ae05f2
adding a test to prove the inter operability with upstream pytorch
blefaudeux Jan 15, 2021
4d59014
updating the changelog
blefaudeux Jan 15, 2021
2754e76
eager state pruning
blefaudeux Jan 15, 2021
3468c0e
torch1.5 compat, this is infinite
blefaudeux Jan 15, 2021
fbee662
reverting the torch1.5 broken fix, the state is very strange in that …
blefaudeux Jan 16, 2021
5589c74
WIP - need to properly handle multiple param groups
blefaudeux Jan 19, 2021
25c2c11
fix + adding this to unit tests
blefaudeux Jan 19, 2021
8d8b923
Merge branch 'oss_views_fix' into oss_elastic_checkpoints
blefaudeux Jan 20, 2021
802c78e
Merge remote-tracking branch 'upstream/master' into oss_elastic_check…
blefaudeux Jan 20, 2021
ce59fd7
raise an exception if not consolidated
blefaudeux Jan 20, 2021
6a3b043
unit test fix following botched merge, now needs to make sure that gr…
blefaudeux Jan 20, 2021
aa29155
some improvement, still need some careful thinking
blefaudeux Jan 20, 2021
9135d54
should be gtg for torch1.6+. torch 1.5 has a different state indexing
blefaudeux Jan 20, 2021
89d14c8
cleanup, still no torch1.5 support but ordering the state dict the wa…
blefaudeux Jan 21, 2021
378d8f1
pytorch 1.5 compat
blefaudeux Jan 21, 2021
2b78172
comment cleanup
blefaudeux Jan 21, 2021
6975075
code review + update from torch PR, remove two loops
blefaudeux Jan 27, 2021
485242c
Merge remote-tracking branch 'upstream/master' into oss_elastic_check…
blefaudeux Jan 27, 2021
3837a18
docfix
blefaudeux Jan 27, 2021
05e84a0
review, thanks Mandeep !
blefaudeux Jan 28, 2021
0d321d9
code review, thanks Mandeep! Penelope mode
blefaudeux Jan 29, 2021
4f8f61a
updated changelog, merge
blefaudeux Feb 1, 2021
b1f9527
Merge remote-tracking branch 'origin/master' into oss_elastic_checkpo…
blefaudeux Feb 1, 2021
ab9f242
linting Pipe files..
blefaudeux Feb 1, 2021
b6d6f5a
isort pipe files, not sure how they got it wrong
blefaudeux Feb 1, 2021
25ca53d
[feature] Automate wheel builds on new releases (#342)
blefaudeux Feb 1, 2021
bf2a352
[chore] Fix lint errors that broke master (#348)
anj-s Feb 1, 2021
24713da
removing the cpu test, not too interesting anyway
blefaudeux Feb 1, 2021
c25db63
Merge remote-tracking branch 'upstream/master' into oss_elastic_check…
blefaudeux Feb 1, 2021
2dabbcc
removing gloo for now
blefaudeux Feb 1, 2021
a07f3f9
Merge branch 'shardedddp-cpu-testfix' into oss_elastic_checkpoints
blefaudeux Feb 1, 2021
830a664
linting, align with new master
blefaudeux Feb 1, 2021
8fb98ce
linting, align with new master
blefaudeux Feb 1, 2021
8b99fe7
removing a clone which is not needed I believe
blefaudeux Feb 2, 2021
d29a240
adding multiple groups to the ddp parity test
blefaudeux Feb 2, 2021
29bc4bd
Merge remote-tracking branch 'origin/master' into oss_elastic_checkpo…
blefaudeux Feb 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ venv/
ENV/
env.bak/
venv.bak/
.vscode/*
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [next rel] - TBD
### Added
Comment thread
blefaudeux marked this conversation as resolved.
- Pytorch compatibility for OSS checkpoints (#310)
- Elastic checkpoints for OSS, world size can vary in between save and loads (#310)
- Tensor views for OSS bucketing, reduced CPU use (#300)
- Bucket calls in ShardedDDP, for faster inter node communications (#327)
- Tensor views for OSS bucketing, reduced CPU use

## [0.1.4] - 2021-01-07
### Fixed
Expand Down
268 changes: 134 additions & 134 deletions fairscale/optim/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@

from collections import OrderedDict, deque
import copy
import itertools
from itertools import chain
import logging
from math import inf
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -80,6 +79,8 @@ def __init__(
self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict() # device, rank, params
self._param_rank: Dict[torch.Tensor, int] = {}
self._partition_parameters: List[List[dict]] = []
self._index_to_param: Dict[int, torch.Tensor] = {}
self._param_to_index: Dict[int, int] = {}

# Build the wrapped optimizer, responsible for a shard of the params
self.group = group if group is not None else dist.group.WORLD
Expand Down Expand Up @@ -142,6 +143,24 @@ def partition_parameters(self) -> List[List[dict]]:

return self._partition_parameters

@property
def index_to_param(self) -> Dict[int, torch.Tensor]:
    """Mapping from a parameter's flat index — params enumerated across all the
    param groups in order, as PyTorch optimizers do — to the parameter itself.

    Built lazily and cached; invalidated along with the other lookup caches.
    """
    if not self._index_to_param:
        every_param = chain.from_iterable(g["params"] for g in self.param_groups)
        self._index_to_param = dict(enumerate(every_param))

    return self._index_to_param

@property
def param_to_index(self) -> Dict[int, int]:
    """Mapping from a parameter's ``id()`` to its flat index in the global optimizer
    scheme (params enumerated across all the param groups, in order).

    This is the inverse lookup of :attr:`index_to_param`, keyed by ``id(param)``
    since the dictionary stores ``Dict[int, int]``.

    Built lazily and cached; invalidated along with the other lookup caches.
    """
    # Fixed: the previous docstring was copy-pasted from index_to_param and
    # described the wrong direction of the mapping.
    if not self._param_to_index:
        self._param_to_index = {
            id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))
        }

    return self._param_to_index

@property
def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
"""Sorted list of all the params, first per device then per rank.
Expand Down Expand Up @@ -191,7 +210,7 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->
.. note: Any extra parameter is passed to the base optimizer as-is"""

# Sync oss param_groups attributes in case they've been updated by a scheduler.
self._sync_param_groups()
OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

# Run the optimizer step on this shard only:
if closure is not None:
Expand All @@ -203,7 +222,7 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->
self._broadcast_params()

# Sync hypothetical new results from the wrapped optimizer to the exposed param_groups
self._sync_param_groups(local_to_global=True)
OSS._sync_param_groups(self.optim.param_groups, self.param_groups)

return loss

Expand Down Expand Up @@ -237,7 +256,7 @@ def clip_grad_norm(
norm_type = float(norm_type)

# Filter out the grad-less params, concatenate params from all devices
local_params = itertools.chain(
local_params = chain(
*[
list(filter(lambda x: x.grad is not None, device_params[self.rank]))
for device_params in self.per_device_params.values()
Expand Down Expand Up @@ -280,26 +299,13 @@ def clip_grad_norm(
return total_norm

# State dict interfaces
def local_state_dict(self) -> dict:
Comment thread
blefaudeux marked this conversation as resolved.
"""Gets this rank's state_dict.

Returns:
The state of the optimizer as a :class:`dict`.
It contains two entries:

* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
"""
return self.optim.state_dict()

def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
"""Update the consolidated state_dict list, one per rank.

.. warning: This needs to be called on all replicas"""

# Sync lr and other attributes in case its been updated
self._sync_param_groups()
OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

if self.rank == recipient_rank:
# Pull the sharded state from all the other replicas
Expand All @@ -310,12 +316,104 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
# Acknowledge broadcasts, and send this rank's shard when needed
self._broadcast_state_dict()

def local_state_dict(self) -> dict:
    """ .. deprecated:: 0.1.5

    Returns this rank's state_dict as a :class:`dict` which contains two entries:

    * state - a dict holding current optimization state. Its content
        differs between optimizer classes.

    * param_groups - a dict containing all parameter groups

    .. warning: This does not represent the optimizer state dict, only a shard.
    """
    import warnings

    # The docstring marks this API as deprecated since 0.1.5; surface that to
    # callers at runtime as well (Python convention for deprecated APIs).
    warnings.warn(
        "local_state_dict is deprecated and only returns this rank's shard, not the full optimizer state",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.optim.state_dict()

def state_dict(self) -> Dict[str, Any]:
    """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
    sharded properties are not exposed. It contains two entries:

    * state - a dict holding current optimization state. Its content
        differs between optimizer classes.

    * param_groups - a dict containing all parameter groups

    Raises:
        RuntimeError: if the state has not been consolidated on this rank beforehand
            (see `consolidate_state_dict`).

    .. warning:
        Returning the global state is limited to the replica which was responsible for the consolidation.
        The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
    """

    if len(self._all_states) == 0:
        raise RuntimeError(
            "Optimizer state has not been consolidated on this rank. \
            Please call `consolidate_state_dict()` on all ranks beforehand if you meant to save the global state"
        )

    # Unify the shard states and the state that pytorch would expect, given the model.
    # Indexation needs several redirections, since each shard only knows a limited scope of the model
    # - get the pytorch compliant parameter indexing
    state_dict = super().state_dict()

    # - go through the per-shard states, which are all indexed locally
    for rank, s in enumerate(self._all_states):
        # -- match the local indexing and the global partition, update the corresponding saved state globally
        for local_pg, global_pg in zip(s["param_groups"], self.partition_parameters()[rank]):
            # Local (per-shard) param index -> id() of the actual parameter tensor,
            # recovered through this rank's partition of the parameters
            local_index_to_param_id = {
                i_param: id(global_pg["params"][i]) for i, i_param in enumerate(local_pg["params"])
            }

            for local_param_index in local_pg["params"]:
                # Update the state, if any
                if local_param_index in s["state"].keys():
                    # id(param) -> global flat index (pytorch convention), then
                    # overwrite that slot in the unified state dict
                    global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
                    state_dict["state"][global_id] = s["state"][local_param_index]

    return state_dict

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Restore the global parameter groups as well as the shard.

    Arguments:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`
    """

    # NOTE: PyTorch 1.5 does not index linearly but with the id(params) at saving time
    # we work around that here by using the fact that the params are ordered as in the param_groups
    for i_param, (key, value) in enumerate(state_dict["state"].items()):
        param = self.index_to_param[i_param]

        # Populate the sharded optimizer state on the fly: states owned by another
        # rank are dropped from this replica's copy
        if self.param_to_rank[param] != self.rank:
            state_dict["state"][key] = None

        if key in self.index_to_param:
            # Only add this state to the sharded optimizer if it owns this param
            # (membership test on id() to avoid tensor equality semantics)
            for pg in self.optim.param_groups:
                if id(param) in {id(p) for p in pg["params"]}:
                    self.optim.state[param] = recursive_copy_to_device(
                        value, non_blocking=True, device=param.device
                    )

    super().load_state_dict(state_dict)

    # Sync with the optimizer param groups
    OSS._sync_param_groups(state_dict["param_groups"], self.param_groups)
    OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

def _broadcast_state_dict(self) -> None:
"""Broadcast this rank's state shard, discard others"""

# Default to CPU space to gain some memory headroom
local_cpu_state = recursive_copy_to_device(
self.local_state_dict(), non_blocking=True, device=torch.device("cpu")
self.optim.state_dict(), non_blocking=True, device=torch.device("cpu")
)

# Tensor cannot be really empty, even if its size is meaningless
Expand Down Expand Up @@ -350,7 +448,7 @@ def _collect_sharded_states(self) -> List[Dict[str, Any]]:
if rank == self.rank:
logging.debug("Saving self state")
all_states.append(
recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
)

# Sync with other replicas
Expand Down Expand Up @@ -378,103 +476,6 @@ def _collect_sharded_states(self) -> List[Dict[str, Any]]:

return all_states

def state_dict(self) -> Dict[str, Any]:
    """Return the last known global optimizer state, which consists of a list of the shards.

    .. warning:
        If the state has not been consolidated, this returns a shard's worth, not the global state.

    .. warning:
        Returning the global state is limited to the replica which was responsible for the consolidation.
        The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
    """

    if not self._all_states:
        logging.warning("Optimizer state has not been consolidated. Returning the local state")
        logging.warning("Please call `consolidate_state_dict()` beforehand if you meant to save the global state")
        local_dict = self.local_state_dict()
        local_dict["local_state_dict"] = True
        return local_dict

    # Flatten the param_groups, save the partition which logs the rank <> shard correspondence
    flat_groups: List[Dict[Any, Any]] = []
    partition: List[Tuple[int, int]] = []
    cursor = 0
    for shard in self._all_states:
        shard_groups = shard["param_groups"]
        flat_groups.extend(shard_groups)
        partition.append((cursor, cursor + len(shard_groups)))
        cursor += len(shard_groups)

    return {
        "state": [shard["state"] for shard in self._all_states],
        "param_groups": flat_groups,
        "partition": partition,
        "local_state_dict": False,
    }

@staticmethod
def rank_local_state_dict(rank: int, state_dict: dict) -> dict:
    """Returns the local_state_dict for a given rank.

    Arguments:
        rank (int): rank to get local_state_dict for
        state_dict (dict): global state_dict
    """
    # The partition records, per rank, the [begin, end) slice of the flattened param_groups
    begin, end = state_dict["partition"][rank]
    return {"state": state_dict["state"][rank], "param_groups": state_dict["param_groups"][begin:end]}

def load_local_state_dict(self, state_dict: dict) -> None:
    """Loads this rank's state_dict.

    .. warning: This is not meant to load the global state dict.
    """

    self.optim.load_state_dict(state_dict)

    # Workaround PyTorch bug that casts state (https://github.com/pytorch/pytorch/issues/43706)
    # Copied from https://github.com/pytorch/fairseq/blob/v0.9.0/fairseq/optim/fp16_optimizer.py#L251-L268
    groups = self.optim.param_groups
    saved_groups = state_dict["param_groups"]
    # Map the saved param ids onto the live parameters, relying on both sides
    # enumerating params in the same (param_group) order
    id_map = {
        old_id: p
        for old_id, p in zip(chain(*(g["params"] for g in saved_groups)), chain(*(g["params"] for g in groups)))
    }
    for k, v in state_dict["state"].items():
        if k in id_map:
            param = id_map[k]
            self.optim.state[param] = recursive_copy_to_device(v, non_blocking=True, device=param.device)

    # Restore the global param_groups (the params themselves are already correct)
    for global_group, local_group in zip(self.param_groups, groups):
        for k, v in local_group.items():
            if k != "params":
                global_group[k] = v

    # Force a re-partitioning, in case the model changed with the new state
    self._partition_parameters.clear()
    self._per_device_params.clear()
    self._param_rank.clear()

    # Update the bucketing strategy accordingly
    self._setup_bucket_strategy()

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Restore the global parameter groups as well as the shard.

    Arguments:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`
    """

    # A local dict restores this shard as-is; a global dict is first narrowed down
    # to this rank's partition before being handed to the wrapped optimizer
    if state_dict.get("local_state_dict", False):
        self.load_local_state_dict(state_dict)
    else:
        self.load_local_state_dict(OSS.rank_local_state_dict(self.rank, state_dict))

def add_param_group(self, param_group: dict) -> None:
"""Add a param group to the :class:`Optimizer` s `param_groups`.

Expand All @@ -491,17 +492,23 @@ def add_param_group(self, param_group: dict) -> None:
super().add_param_group(param_group)
if not self.in_super_constructor:
# Force a re-partitioning
self._partition_parameters.clear()
self._per_device_params.clear()
self._param_rank.clear()
self._clear_cache()

# Update the partition
param_groups = self.partition_parameters()[self.rank]
if len(param_groups) == len(self.optim.param_groups) + 1:
self.optim.add_param_group(param_groups[-1])

# Update the bucketing strategy accordingly
self._setup_bucket_strategy()

def _clear_cache(self) -> None:
    """Invalidate every lazily-built lookup structure, forcing a rebuild on next access."""
    lookup_caches = (
        self._partition_parameters,
        self._per_device_params,
        self._param_rank,
        self._index_to_param,
        self._param_to_index,
    )
    for cache in lookup_caches:
        cache.clear()

@staticmethod
def get_global_rank(group: Any, rank: int) -> int:
if group is dist.group.WORLD:
Expand All @@ -510,20 +517,14 @@ def get_global_rank(group: Any, rank: int) -> int:
global_rank = dist.distributed_c10d._get_global_rank(group, rank)
return global_rank

@torch.no_grad()
def _sync_param_groups(self, local_to_global: bool = False) -> None:
"""Sync learning rate and other optimizer attributes (needed to support schedulers).
If the global param groups have been altered, and we want to make sure that the
wrapped optimizer uses the up to date version.
Conversely if the wrapped optimizer has new keys, we expose them through the global param groups"""
@staticmethod
def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:
    """Sync learning rate and other optimizer attributes (needed to support schedulers)."""

    for src_group, dst_group in zip(source, destination):
        # Copy every attribute over, but never touch the parameter list itself
        dst_group.update({k: v for k, v in src_group.items() if k != "params"})

@torch.no_grad()
def _broadcast_params(self) -> None:
Expand Down Expand Up @@ -614,7 +615,6 @@ def _setup_bucket_strategy(self) -> None:

self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten())
param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)

offset = offset_next
else:
self.should_bucket_param.append(False)
Expand Down
Loading