Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 112 additions & 47 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
Expand All @@ -10,10 +13,16 @@
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,
get_shard_filename,
save_param_groups,
save_state_dict,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper

from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
Expand All @@ -32,21 +41,104 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):

class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
"""Save optimizer to checkpoint but only on master process.

Args:
optimizer (OptimizerWrapper): Optimizer to save state_dict
checkpoint (str): Path to save checkpoint
gather_dtensor (bool): Whether to gather_dtensor, not used
"""
# TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
warnings.warn(
'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
warnings.warn(
'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().load_optimizer(optimizer, checkpoint)
# the `state_dict` in LowLevelZeroOptimizer has communication
# if only the master rank collect state_dict and save,
# the communication on each rank would not match
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def save_sharded_optimizer(self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = False,
prefix: str = None,
size_per_shard: int = 1024):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way

Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Perfix of file to save
size_per_shard (int): Max file size of each file that store state tensors
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return

Path(checkpoint).mkdir(parents=True, exist_ok=True)

# state_dict only provide only 'param_groups'
state_dict = optimizer.optim.state_dict()
# state shard would be handled by the low-level zero optimizer
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)

# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)

# Store the information of param groups to param_group_file.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(state_dict, group_file_path)

# Save shards of optimizer states.
total_size = 0
for idx, shard_pair in enumerate(sharded_state):
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for param_id in shard.keys():
index_file.append_weight_map(str(param_id), shard_file)

checkpoint_file_path = os.path.join(checkpoint, shard_file)
if self.coordinator.is_master():
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")

def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
"""Load sharded optimizer with the given path to index file.

Args:
optimizer (OptimizerWrapper): Optimizer to load state_dict
index_file_path (str): Path to the index file
prefix (str): Not used.
"""
super().load_sharded_optimizer(optimizer, index_file_path, prefix)
current_rank_state_dict = optimizer.optim.state_dict()['state']
for param_idx, state in current_rank_state_dict.items():
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
padding_size = (self.coordinator.world_size -
v.numel() % self.coordinator.world_size) % self.coordinator.world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self.coordinator.world_size)
current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()


class LowLevelZeroModel(ModelWrapper):
Expand Down Expand Up @@ -74,36 +166,6 @@ def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)


class LowLevelZeroOptimizer(OptimizerWrapper):

def __init__(self,
module: nn.Module,
optimizer: Optimizer,
zero_optim_config: dict,
optim_kwargs: dict,
verbose: bool = False) -> None:
optimizer = zero_optim_wrapper(module,
optimizer,
optim_config=zero_optim_config,
**optim_kwargs,
verbose=verbose)
super().__init__(optimizer)

def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)

def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')

def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')


class LowLevelZeroPlugin(DPPluginBase):
"""
Plugin for low level zero.
Expand Down Expand Up @@ -211,8 +273,11 @@ def configure(

if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)
optimizer = zero_optim_wrapper(model.unwrap(),
optimizer,
optim_config=self.zero_optim_config,
**self.optim_kwargs,
verbose=self.verbose)

return model, optimizer, criterion, dataloader, lr_scheduler

Expand Down
77 changes: 63 additions & 14 deletions colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import copy
from contextlib import contextmanager
from functools import partial
from typing import Optional
from typing import Dict, Iterator, Optional, Tuple

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -447,18 +447,23 @@ def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# Gradient Synchronization #
############################

# this method is used to sync gradient manually
def sync_grad(self):
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad and param.grad is not None:
self._add_to_bucket(param, group_id)

self._run_reduction()

def _reduce_grad(self, partition_grad):
# if not overlapping communication (no reduction hook is attached) when zero1
# we need to manually reduce these gradients
if not partition_grad and not self._overlap_communication:
for group_id in range(len(self._working_param_groups)):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.grad is not None:
self._add_to_bucket(param, group_id)

# run reduction
self._run_reduction()
self.sync_grad()
else:
self._run_reduction()

# this context comes from pytorch DDP
@contextmanager
Expand All @@ -473,7 +478,8 @@ def no_sync(self):
##############
# State Dict #
##############
def _pack_state(self, state: dict) -> dict:

def _pack_state(self, state: Dict) -> Dict:
# comes from pytorch optimizer.state_dict()
param_mappings = {}
start_index = 0
Expand All @@ -487,17 +493,17 @@ def pack_group(group):
start_index += len(packed['params'])
return packed

param_groups = [pack_group(g) for g in self.param_groups]
param_groups = [pack_group(g) for g in self.optim.param_groups]
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}

return {'state': packed_state, 'param_groups': param_groups}

def state_dict(self) -> dict:
def state_dict(self) -> Dict:
"""Return a state_dict same with DDP

Returns:
dict: the pytorch form state_dict
Dict: the pytorch form state_dict
"""
zero_state = dict()
for param, state in self.optim.state.items():
Expand All @@ -514,7 +520,7 @@ def state_dict(self) -> dict:

return states_dict

def load_state_dict(self, state_dict: dict):
def load_state_dict(self, state_dict: Dict):
"""Load state dict, requires the state_dict be the pytorch form

Args:
Expand All @@ -534,3 +540,46 @@ def load_state_dict(self, state_dict: dict):

self.optim.load_state_dict(zero_state_dict)
zero_state_dict = dict()

def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
Only include the 'state' in state_dict.

Args:
max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024.

Yields:
Iterator[OrderedDict]: A generator of state dict shard
"""
ret_block = dict()
ret_block_size = 0

local_states = self.optim.state_dict()['state']
for param_idx, states in local_states.items():
current_block_size = 0
current_block = copy.deepcopy(states)

# find the working param of current param_id
for group_id, pg in self._master_param_groups_of_current_rank.items():
if (group_id + 1) * len(pg) < param_idx:
continue
master_param = pg[param_idx - (group_id) * len(pg)]
working_param = self._param_store.master_to_working_param[id(master_param)]

for k, v in states.items():
if isinstance(v, torch.Tensor) and k != 'step':
state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
dist.all_gather(state_tensor, v, group=self.dp_pg)
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
current_block_size += state_tensor.numel()
current_block[k] = state_tensor

if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
yield ret_block, ret_block_size
ret_block = dict()
ret_block_size = 0

ret_block[param_idx] = current_block
ret_block_size += current_block_size

yield ret_block, ret_block_size
54 changes: 54 additions & 0 deletions colossalai/zero/low_level/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Low Level ZeRO
>Low Level ZeRO == ZeRO-DP stage 1 and 2, we would denote it as ZeRO.

## Design:
### Notion
`p32` denotes the param copy in the optimizer
`p` denotes the model param
`g` denotes the gradient

### INIT
In low level zero(1, 2), `p32` is split. Different from the previous implement, we split each `p32` evenly by world_size. Thus, rank0 got a param list as `[p00, p10]`, rank1 got a param list as `[p-01, p-11]`, etc.
<img width="840" alt="image" src="https://github.com/hpcaitech/ColossalAI/assets/74758262/f7758d7d-c5e5-44a4-a121-3aba8b05c904">

For the detailed implementation, we first pad `p` for it can be split by world_size if needed. Then, we would view it to the shape `[world_size, -1]`, and each rank got its own part `p32` by cloning.

### BWD
To leverage the communication, a gradient would be added to a bucket first. When the bucket is full, each `g` in it would be reshaped as `[world_size, -1]`. And the `[local_rank]` parts would be united.
The data structure looks like this:
```
{
0: [g-00, g-10],
1: [g-01, g-11],
2: [g-02, g-12]
}
```
After that, the gradients would be flattened by rank, and the data structure looks like this:
```
# g-0 means flatten([g-00, g-10])
{
0: [g-0],
1: [g-1],
2: [g-2]
}
```
For zero1, we iterate the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`.

### Optim
For each rank gets its own `p32` and the counterpart `g`, it is quite easy to do `optim.step()`.

However, we have to consider a situation of layer drop, for instance:
```
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.drop_linear = nn.Linear(256, 256)
self.linear2 = nn.Linear(256, 512)

def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
```
And the solution is to build a mapping of `p32`, `p`, and `g`. Before `optim.step()`, we collect `p` which `requires_grad=True` and `p.grad != None` as a real working param. And select the counterpart `p32` and `g`.
Loading