Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 116 additions & 13 deletions colossalai/zero/gemini/gemini_ddp.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import itertools
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional, Union
from typing import Dict, Iterator, List, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
from colossalai.tensor import ProcessGroup as ColoProcessGroup
Expand Down Expand Up @@ -228,6 +229,32 @@ def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0:
destination = hook_result
return destination

def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
    """Gather one chunk and slice out each contained tensor's payload for saving.

    Args:
        chunk (Chunk): the chunk whose tensors should be materialized.
        only_rank_0 (bool): if True, only rank 0 of the chunk's process group
            records real data; other ranks record empty placeholder tensors.

    Returns:
        Dict: maps each tensor in the chunk to a CPU tensor holding its
            gathered payload (or an empty tensor on non-saving ranks).
    """
    chunk_to_save_data = dict()
    # Gather the whole chunk onto CUDA once, then slice per tensor.
    temp_chunk = get_temp_total_chunk_on_cuda(chunk)
    for tensor, tensor_info in chunk.tensors_info.items():
        record_tensor = torch.empty([0])
        # Use short-circuit `or` (not bitwise `|`): when every rank saves,
        # there is no need to query the rank at all.
        record_flag = (not only_rank_0) or (dist.get_rank(chunk.torch_pg) == 0)
        if record_flag:
            record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()

        assert tensor not in chunk_to_save_data
        chunk_to_save_data[tensor] = record_tensor

    # Release the temporary gathered buffer promptly to cap peak CUDA memory.
    del temp_chunk
    return chunk_to_save_data

def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
"""
get param content from chunks.
Expand All @@ -243,18 +270,7 @@ def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_ran
param_to_save_data = dict()
chunk_list = self.chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)

for tensor, tensor_info in chunk.tensors_info.items():
record_tensor = torch.empty([0])
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
if record_flag:
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()

assert tensor not in param_to_save_data
param_to_save_data[tensor] = record_tensor

del temp_chunk
param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
return param_to_save_data

def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
Expand Down Expand Up @@ -554,6 +570,93 @@ def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor'])
p.__class__ = ColoParameter
p.__init__(p, requires_grad=requires_grad)

def state_dict_shard(self,
                     prefix: str = '',
                     keep_vars: bool = False,
                     max_shard_size: int = 1024,
                     only_rank_0: bool = True) -> Iterator[OrderedDict]:
    """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.

    Both parameters and persistent buffers (e.g. running averages) are included.
    Keys are corresponding parameter and buffer names.
    Parameters and buffers set to ``None`` are not included.

    Args:
        prefix (str, optional): the prefix for parameters and buffers used in this
            module. Defaults to ''.
        keep_vars (bool, optional): whether to keep variables. Defaults to False.
        max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
        only_rank_0 (bool, optional): only get data on rank0. Defaults to True.


    Yields:
        Iterator[OrderedDict]: A generator of state dict shard
    """
    sharder = _StateDictSharder(max_shard_size)

    # get the mapping between copies and fp16 parameters
    # NOTE(review): assumes fp16_params and fp32_params are aligned
    # element-wise — presumably guaranteed by __init__; confirm there.
    fp16_to_fp32 = dict()
    for p, fp32_p in zip(self.fp16_params, self.fp32_params):
        fp16_to_fp32[p] = fp32_p

    # key is fp32 param, and value is gathered param on CPU
    gathered_param_buffer = dict()
    for name, param in self.name2param.items():
        if param is not None:
            if is_ddp_ignored(param):
                # deal with ddp ignored parameters
                gathered_param = param if keep_vars else param.detach()
            else:
                fp32_param = fp16_to_fp32[param]
                if fp32_param not in gathered_param_buffer:
                    # Gather the whole chunk in one collective; the buffer lets
                    # sibling params of the same chunk reuse that gathered data.
                    chunk = self.chunk_manager.get_chunk(fp32_param)
                    gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
                # pop() drops the gathered copy once it is handed to the sharder
                gathered_param = gathered_param_buffer.pop(fp32_param)

            block = sharder.append(prefix + name, gathered_param)
            if block is not None:
                yield block

    del fp16_to_fp32
    del gathered_param_buffer

    # save all buffers
    for name, buf in self.named_buffers():
        if buf is not None and name not in self._non_persistent_buffers_set:
            buffer = buf if keep_vars else buf.detach()
            block = sharder.append(prefix + name, buffer)
            if block is not None:
                yield block
    # save extra states
    # (only if the subclass actually overrides get_extra_state, mirroring
    # torch.nn.Module._save_to_state_dict)
    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    if getattr(self.__class__, "get_extra_state",
               torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
        extra_state = self.get_extra_state()
        block = sharder.append(extra_state_key, extra_state)
        if block is not None:
            yield block

    # flush whatever remains in the last (possibly undersized) shard
    yield sharder.current_block


class _StateDictSharder:
    """Accumulates (name, tensor) entries into shards bounded by ``max_shard_size``."""

    def __init__(self, max_shard_size: int) -> None:
        # max_shard_size uses the same unit as calculate_tensor_size (MB)
        self.max_shard_size = max_shard_size
        self.current_block = OrderedDict()
        self.current_block_size = 0

    def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
        """Add a tensor to the current shard.

        Returns:
            Optional[OrderedDict]: the completed previous shard when adding
                ``tensor`` would overflow ``max_shard_size``, otherwise None.
        """
        tensor_size = calculate_tensor_size(tensor)
        ret_block = None
        # Only flush a non-empty block: a single tensor larger than
        # max_shard_size must not produce a spurious empty shard.
        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
            ret_block = self.current_block
            self.current_block = OrderedDict()
            self.current_block_size = 0
        self.current_block[name] = tensor
        self.current_block_size += tensor_size
        return ret_block


class GeminiDDP(ZeroDDP):

Expand Down
56 changes: 56 additions & 0 deletions tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest
import torch
from torch.testing import assert_close

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict(placement_policy, model_name: str):
    """Verify that state_dict_shard reproduces state_dict exactly, key by key."""
    components_fn = non_distributed_component_funcs.get_callable(model_name)
    model_builder, *_ = components_fn()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    # model size in MB, used below to force the state dict into several shards
    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2

    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    gemini_manager = GeminiManager(placement_policy, ChunkManager(config_dict))
    model = ZeroDDP(model, gemini_manager)
    model.train()

    zero_dict = model.state_dict(only_rank_0=False)
    accumulated_keys = set()
    # a shard cap of one third of the model guarantees more than one shard
    for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
        for key, value in shard.items():
            assert key not in accumulated_keys, f"key `{key}` is duplicated."
            accumulated_keys.add(key)
            assert key in zero_dict, f"{key} not in ZeRO dictionary."
            assert torch.equal(value, zero_dict[key]), f"{key} not equal."


def run_dist(rank, world_size, port):
    """Set up a distributed context on this rank, then run the state-dict check."""
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp_state_dict_shard(world_size):
    # spawn `world_size` workers; each runs every (placement, model) combination
    spawn(run_dist, world_size)


# allow running this test directly without pytest (single worker)
if __name__ == '__main__':
    test_zero_ddp_state_dict_shard(1)