From 72687ce2b6154d52d8009c7d9a502d40c365c134 Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 17 Apr 2023 14:48:45 +0800 Subject: [PATCH 1/5] [gemini] support state dict shard --- colossalai/zero/gemini/gemini_ddp.py | 114 +++++++++++++++++++++++---- 1 file changed, 97 insertions(+), 17 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 2e35be0661e9..4a495d4cf8f3 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -1,12 +1,13 @@ import itertools from collections import OrderedDict from functools import partial -from typing import Dict, List, Optional, Union +from typing import Dict, Iterator, List, Optional, Union import torch import torch.distributed as dist import torch.nn as nn +from colossalai.checkpoint_io.utils import calculate_tensor_size from colossalai.logging import get_dist_logger from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage from colossalai.tensor import ProcessGroup as ColoProcessGroup @@ -228,6 +229,32 @@ def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: destination = hook_result return destination + def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: + """ + get gathered chunk content. + + Args: + chunk (Chunk): a chunk + only_rank_0 (bool): whether to only save data on rank 0 + + Returns: + Dict: a dict whose key is param name and value is param with correct payload + """ + # save parameters + chunk_to_save_data = dict() + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + for tensor, tensor_info in chunk.tensors_info.items(): + record_tensor = torch.empty([0]) + record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) + if record_flag: + record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() + + assert tensor not in chunk_to_save_data + chunk_to_save_data[tensor] = record_tensor + + del temp_chunk + return chunk_to_save_data + def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict: """ get param content from chunks. @@ -243,18 +270,7 @@ def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_ran param_to_save_data = dict() chunk_list = self.chunk_manager.get_chunks(param_list) for chunk in chunk_list: - temp_chunk = get_temp_total_chunk_on_cuda(chunk) - - for tensor, tensor_info in chunk.tensors_info.items(): - record_tensor = torch.empty([0]) - record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) - if record_flag: - record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() - - assert tensor not in param_to_save_data - param_to_save_data[tensor] = record_tensor - - del temp_chunk + param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0)) return param_to_save_data def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): @@ -554,6 +570,72 @@ def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) p.__class__ = ColoParameter p.__init__(p, requires_grad=requires_grad) + def state_dict_shard(self, + prefix: str = '', + keep_vars: bool = False, + max_shard_size: int = 1024, + only_rank_0: bool = True) -> Iterator[OrderedDict]: + # get the mapping between copies and fp16 parameters + fp16_to_fp32 = dict() + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + fp16_to_fp32[p] = fp32_p + + # key is fp32 param, and value is gathered param on CPU + gathered_param_buffer = dict() + + current_block = OrderedDict() + current_block_size = 0 + + def _save_to_block(name, tensor): + nonlocal current_block, current_block_size + tensor_size = calculate_tensor_size(tensor) + returned_block = None + if current_block_size + tensor_size > max_shard_size: + returned_block = current_block + current_block = OrderedDict() + current_block_size = 0 + current_block[name] = tensor + current_block_size += tensor_size + return returned_block + + for name, param in self.name2param.items(): + if param is not None: + if is_ddp_ignored(param): + # deal with ddp ignored parameters + gathered_param = param if keep_vars else param.detach() + else: + fp32_param = fp16_to_fp32[param] + if fp32_param not in gathered_param_buffer: + chunk = self.chunk_manager.get_chunk(fp32_param) + gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0)) + gathered_param = gathered_param_buffer.pop(fp32_param) + + tmp_block = _save_to_block(prefix + name, gathered_param) + if tmp_block is not None: + yield tmp_block + + assert len(gathered_param_buffer) == 0, "gathered_param_buffer should be empty after state_dict_shard" + del fp16_to_fp32 + del gathered_param_buffer + + # save all buffers + for name, buf in self.named_buffers(): + if buf is not None and name not in self._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + tmp_block = _save_to_block(prefix + name, buffer) + if tmp_block is not None: + yield tmp_block + # save extra states + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + extra_state = self.get_extra_state() + tmp_block = _save_to_block(extra_state_key, extra_state) + if tmp_block is not None: + yield tmp_block + + yield current_block + class GeminiDDP(ZeroDDP): @@ -567,8 +649,7 @@ def __init__(self, search_range_mb: int = 32, hidden_dim: Optional[int] = None, min_chunk_size_mb: float = 32, - memstats: Optional[MemStats] = None, - verbose: bool = False) -> None: + memstats: Optional[MemStats] = None) -> None: """ A torch.Module warpper using ZeRO-DP and Genimi. ZeRO is for parallel. Gemini is for memory management. @@ -605,7 +686,6 @@ def __init__(self, hidden_dim=hidden_dim, search_range_mb=search_range_mb, min_chunk_size_mb=min_chunk_size_mb, - strict_ddp_flag=strict_ddp_mode, - verbose=verbose) + strict_ddp_flag=strict_ddp_mode) gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode) From c18757cfd7034e19af92c072c5529317c905eb70 Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 17 Apr 2023 15:12:38 +0800 Subject: [PATCH 2/5] [gemini] add test state dict shard --- colossalai/zero/gemini/gemini_ddp.py | 1 - .../test_zeroddp_state_dict_shard.py | 56 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 4a495d4cf8f3..467beff0c888 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -614,7 +614,6 @@ def _save_to_block(name, tensor): if tmp_block is not None: yield tmp_block - assert len(gathered_param_buffer) == 0, "gathered_param_buffer should be empty after state_dict_shard" del fp16_to_fp32 del gathered_param_buffer diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py new file mode 100644 index 000000000000..96c26a1de4df --- /dev/null +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py @@ -0,0 +1,56 @@ +import pytest +import torch +from torch.testing import assert_close + +import colossalai +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test.registry import non_distributed_component_funcs + + +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('model_name', ['gpt2', 'bert']) +def exam_state_dict(placement_policy, model_name: str): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + + model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 + + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager) + model.train() + + zero_dict = model.state_dict(only_rank_0=False) + accumulated_keys = set() + # ensure number of shards > 1 + for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): + for key, value in shard.items(): + assert key not in accumulated_keys, f"key `{key}` is duplicated." + accumulated_keys.add(key) + assert key in zero_dict, f"{key} not in ZeRO dictionary." + assert torch.equal(value, zero_dict[key]), f"{key} not equal." + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_zero_ddp_state_dict_shard(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_zero_ddp_state_dict_shard(1) From 6cfb96b6d59ce6bb3aee9e9b593a12e5b8b679fd Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 17 Apr 2023 15:17:26 +0800 Subject: [PATCH 3/5] [gemini] polish docstr --- colossalai/zero/gemini/gemini_ddp.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 467beff0c888..02ac3ce440ea 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -575,6 +575,23 @@ def state_dict_shard(self, keep_vars: bool = False, max_shard_size: int = 1024, only_rank_0: bool = True) -> Iterator[OrderedDict]: + """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. + + Both parameters and persistent buffers (e.g. running averages) are included. + Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + Args: + prefix (str, optional): the prefix for parameters and buffers used in this + module. Defaults to ''. + keep_vars (bool, optional): whether to keep variables. Defaults to False. + max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024. + only_rank_0 (bool, optional): only get data on rank0. Defaults to True. + + + Yields: + Iterator[OrderedDict]: A generator of state dict shard + """ # get the mapping between copies and fp16 parameters fp16_to_fp32 = dict() for p, fp32_p in zip(self.fp16_params, self.fp32_params): From c42481f2794e1e40cc9df6d66efeff257681c64c Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 17 Apr 2023 15:19:29 +0800 Subject: [PATCH 4/5] [gemini] fix merge --- colossalai/zero/gemini/gemini_ddp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 02ac3ce440ea..8bc2c23aa9e1 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -665,7 +665,8 @@ def __init__(self, search_range_mb: int = 32, hidden_dim: Optional[int] = None, min_chunk_size_mb: float = 32, - memstats: Optional[MemStats] = None) -> None: + memstats: Optional[MemStats] = None, + verbose: bool = False) -> None: """ A torch.Module warpper using ZeRO-DP and Genimi. ZeRO is for parallel. Gemini is for memory management. @@ -702,6 +703,7 @@ def __init__(self, hidden_dim=hidden_dim, search_range_mb=search_range_mb, min_chunk_size_mb=min_chunk_size_mb, - strict_ddp_flag=strict_ddp_mode) + strict_ddp_flag=strict_ddp_mode, + verbose=verbose) gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode) From 67b6a2561a24eaf8252c1910f89d57fa6583b40c Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 17 Apr 2023 15:49:29 +0800 Subject: [PATCH 5/5] [gemini] polish code --- colossalai/zero/gemini/gemini_ddp.py | 59 +++++++++++++++------------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 8bc2c23aa9e1..9a193310bab1 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -592,6 +592,8 @@ def state_dict_shard(self, Yields: Iterator[OrderedDict]: A generator of state dict shard """ + sharder = _StateDictSharder(max_shard_size) + # get the mapping between copies and fp16 parameters fp16_to_fp32 = dict() for p, fp32_p in zip(self.fp16_params, self.fp32_params): @@ -599,22 +601,6 @@ def state_dict_shard(self, # key is fp32 param, and value is gathered param on CPU gathered_param_buffer = dict() - - current_block = OrderedDict() - current_block_size = 0 - - def _save_to_block(name, tensor): - nonlocal current_block, current_block_size - tensor_size = calculate_tensor_size(tensor) - returned_block = None - if current_block_size + tensor_size > max_shard_size: - returned_block = current_block - current_block = OrderedDict() - current_block_size = 0 - current_block[name] = tensor - current_block_size += tensor_size - return returned_block - for name, param in self.name2param.items(): if param is not None: if is_ddp_ignored(param): @@ -627,9 +613,9 @@ def _save_to_block(name, tensor): gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0)) gathered_param = gathered_param_buffer.pop(fp32_param) - tmp_block = _save_to_block(prefix + name, gathered_param) - if tmp_block is not None: - yield tmp_block + block = sharder.append(prefix + name, gathered_param) + if block is not None: + yield block del fp16_to_fp32 del gathered_param_buffer @@ -638,19 +624,38 @@ def _save_to_block(name, tensor): for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - tmp_block = _save_to_block(prefix + name, buffer) - if tmp_block is not None: - yield tmp_block + block = sharder.append(prefix + name, buffer) + if block is not None: + yield block # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX if getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = self.get_extra_state() - tmp_block = _save_to_block(extra_state_key, extra_state) - if tmp_block is not None: - yield tmp_block - - yield current_block + block = sharder.append(extra_state_key, extra_state) + if block is not None: + yield block + + yield sharder.current_block + + +class _StateDictSharder: + + def __init__(self, max_shard_size: int) -> None: + self.max_shard_size = max_shard_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]: + tensor_size = calculate_tensor_size(tensor) + ret_block = None + if self.current_block_size + tensor_size > self.max_shard_size: + ret_block = self.current_block + self.current_block = OrderedDict() + self.current_block_size = 0 + self.current_block[name] = tensor + self.current_block_size += tensor_size + return ret_block class GeminiDDP(ZeroDDP):