From f0636893d8299837b147d3bcb55c5e7b820d1b9d Mon Sep 17 00:00:00 2001
From: eric8607242
Date: Fri, 6 Jan 2023 12:12:58 +0800
Subject: [PATCH 1/3] Fix state_dict output for ZeroDDP duplicated parameters

---
 colossalai/nn/parallel/data_parallel.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index cbef6f532412..92eb9af998d1 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -213,6 +213,7 @@ def __init__(self,
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(gemini_manager)
         self.fp32_params: List[ColoTensor] = []
+        self.name_to_fp32_params: Dict[str, ColoTensor] = {}
         self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
 
@@ -228,6 +229,7 @@ def __init__(self,
             for p in module.parameters():
                 param_order.append(p)
 
+        params_to_fp32_params = {}
         for p in param_order.generate():
             assert isinstance(p, ColoParameter)
 
@@ -249,8 +251,14 @@ def __init__(self,
                                                config_key=dp_world_size,
                                                cpu_offload=cpu_offload,
                                                pin_memory=pin_memory)
+            params_to_fp32_params[p] = fp32_p
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
+
+        # Set keep_vars=True to prevent the parameters from being detached
+        for name, p in module.state_dict(keep_vars=True).items():
+            if p in params_to_fp32_params:
+                self.name_to_fp32_params[name] = params_to_fp32_params[p]
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
 
@@ -418,11 +426,15 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
 
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
 
         param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
-        for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
-            if p is not None:
+        for name, p in self.module.state_dict().items():
+            if p is not None and name in self.name_to_fp32_params:
+                fp32_p = self.name_to_fp32_params[name]
+                assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
                 record_parameter = param_to_save_data[fp32_p]
-                destination[prefix + name] = record_parameter
+            else:
+                record_parameter = p
+            destination[prefix + name] = record_parameter
 
         # save all buffers
         for name, buf in self.named_buffers():

From 1a8ed43b0a599f3b838565c9a34ff63dca91e3c5 Mon Sep 17 00:00:00 2001
From: eric8607242
Date: Sat, 7 Jan 2023 09:44:36 +0800
Subject: [PATCH 2/3] Rewrite state_dict based on get_static_torch_model

---
 colossalai/nn/parallel/data_parallel.py | 53 +++++++++++++++----------
 colossalai/nn/parallel/utils.py         |  2 +-
 2 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index ea600505734d..faf4ea1daa15 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -18,6 +18,7 @@
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
 
 from .reducer import Reducer
+from .utils import get_static_torch_model
 
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@@ -213,7 +214,6 @@ def __init__(self,
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(gemini_manager)
         self.fp32_params: List[ColoTensor] = []
-        self.name_to_fp32_params: Dict[str, ColoTensor] = {}
         self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
 
@@ -229,7 +229,6 @@ def __init__(self,
             for p in module.parameters():
                 param_order.append(p)
 
-        params_to_fp32_params = {}
         for p in param_order.generate():
             assert isinstance(p, ColoParameter)
 
@@ -251,14 +250,9 @@ def __init__(self,
                                                config_key=dp_world_size,
                                                cpu_offload=cpu_offload,
                                                pin_memory=pin_memory)
-            params_to_fp32_params[p] = fp32_p
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
-
-        # Set keep_vars=True to prevent the parameters from being detached
-        for name, p in module.state_dict(keep_vars=True).items():
-            if p in params_to_fp32_params:
-                self.name_to_fp32_params[name] = params_to_fp32_params[p]
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
 
@@ -339,12 +333,11 @@ def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
         for tensor in chunk.get_tensors():
             self.grads_device[tensor] = device
 
-    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
-        r"""Returns a dictionary containing a whole state of the module.
-
-        Both parameters and persistent buffers (e.g. running averages) are
-        included. Keys are corresponding parameter and buffer names.
-        Parameters and buffers set to ``None`` are not included.
+    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
+        r"""
+        Args:
+            strict (bool): whether to return the whole model state
+                as the original PyTorch ``state_dict()`` does
 
         Returns:
             dict:
@@ -354,7 +347,31 @@ def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0:
             >>> module.state_dict().keys()
             ['bias', 'weight']
 
+        """
+        if strict:
+            return get_static_torch_model(gemini_ddp_model=self, device=get_current_device(),
+                                          only_rank_0=only_rank_0).state_dict(destination=destination,
+                                                                              prefix=prefix,
+                                                                              keep_vars=keep_vars)
+        return self._non_strict_state_dict(destination=destination,
+                                           prefix=prefix,
+                                           keep_vars=keep_vars,
+                                           only_rank_0=only_rank_0)
+
+    def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
+        r"""Returns a dictionary containing a whole state of the module.
+
+        Both parameters and persistent buffers (e.g. running averages) are
+        included. Keys are corresponding parameter and buffer names.
+        Parameters and buffers set to ``None`` are not included.
+
+        Warning: The non-strict state dict ignores a parameter if its
+        tensor is shared with another parameter that has already been
+        included in the dictionary.
+
+        Returns:
+            dict:
+                a dictionary containing a whole state of the module
         """
         if destination is None:
             destination = OrderedDict()
@@ -414,15 +431,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
 
         param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
         # TODO: (HELSON) deal with ddp ignored parameters
-        for name, p in self.module.state_dict().items():
-            if p is not None and name in self.name_to_fp32_params:
-                fp32_p = self.name_to_fp32_params[name]
-
+        for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
+            if p is not None:
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
                 record_parameter = param_to_save_data[fp32_p]
-            else:
-                record_parameter = p
-            destination[prefix + name] = record_parameter
+                destination[prefix + name] = record_parameter
 
         # save all buffers
         for name, buf in self.named_buffers():
diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py
index 1205cbc3a658..8f83606161c3 100644
--- a/colossalai/nn/parallel/utils.py
+++ b/colossalai/nn/parallel/utils.py
@@ -81,7 +81,7 @@ def get_static_torch_model(gemini_ddp_model,
     from colossalai.nn.parallel import GeminiDDP
     assert isinstance(gemini_ddp_model, GeminiDDP)
 
-    state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
+    state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
     colo_model = gemini_ddp_model.module
     torch_model = _get_shallow_copy_model(colo_model)
 

From 4037c9d309afc3b9c1fc7001dc09268e61aab116 Mon Sep 17 00:00:00 2001
From: eric8607242
Date: Mon, 9 Jan 2023 12:42:36 +0800
Subject: [PATCH 3/3] Modify get_static_torch_model to be compatible with the
 lower version (ZeroDDP)

---
 colossalai/nn/parallel/data_parallel.py |  2 +-
 colossalai/nn/parallel/utils.py         | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index faf4ea1daa15..8fd08db957b7 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -349,7 +349,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0:
             ['bias', 'weight']
 
         """
         if strict:
-            return get_static_torch_model(gemini_ddp_model=self, device=get_current_device(),
+            return get_static_torch_model(zero_ddp_model=self, device=get_current_device(),
                                           only_rank_0=only_rank_0).state_dict(destination=destination,
                                                                               prefix=prefix,
                                                                               keep_vars=keep_vars)
diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py
index 8f83606161c3..988f978254a1 100644
--- a/colossalai/nn/parallel/utils.py
+++ b/colossalai/nn/parallel/utils.py
@@ -60,17 +60,17 @@ def _get_shallow_copy_model(model: nn.Module):
     return name_to_module['']
 
 
-def get_static_torch_model(gemini_ddp_model,
+def get_static_torch_model(zero_ddp_model,
                            device=torch.device("cpu"),
                            dtype=torch.float32,
                            only_rank_0=True) -> torch.nn.Module:
-    """Get a static torch.nn.Module model from the given GeminiDDP module.
-    You should notice that the original GeminiDDP model is not modified.
+    """Get a static torch.nn.Module model from the given ZeroDDP module.
+    You should notice that the original ZeroDDP model is not modified.
     Thus, you can use the original model in further training.
     But you should not use the returned torch model to train, as this can
     cause unexpected errors.
 
     Args:
-        gemini_ddp_model (GeminiDDP): a gemini ddp model
+        zero_ddp_model (ZeroDDP): a zero ddp model
         device (torch.device): the device of the final torch model
         dtype (torch.dtype): the dtype of the final torch model
         only_rank_0 (bool): if True, only rank 0 has the converted torch model
 
@@ -78,11 +78,11 @@ def get_static_torch_model(gemini_ddp_model,
     Returns:
         torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
     """
-    from colossalai.nn.parallel import GeminiDDP
-    assert isinstance(gemini_ddp_model, GeminiDDP)
+    from colossalai.nn.parallel import ZeroDDP
+    assert isinstance(zero_ddp_model, ZeroDDP)
 
-    state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
-    colo_model = gemini_ddp_model.module
+    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
+    colo_model = zero_ddp_model.module
     torch_model = _get_shallow_copy_model(colo_model)
 
     if not only_rank_0 or dist.get_rank() == 0:
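
Usage note: a minimal sketch of the API after this series, assuming `model` is
an already-initialized GeminiDDP (a subclass of ZeroDDP) instance; the
surrounding setup (chunk manager, Gemini manager) is omitted, and the
checkpoint path is illustrative.

    import torch

    from colossalai.nn.parallel.utils import get_static_torch_model

    # strict=True (the default) rebuilds a plain torch-style state dict via
    # get_static_torch_model, so duplicated/shared parameters keep all of
    # their keys and the checkpoint loads into an unwrapped torch.nn.Module.
    torch.save(model.state_dict(), "checkpoint.pt")

    # strict=False returns the non-strict dict: a tensor shared by several
    # parameters is recorded under only one of its names.
    non_strict_state = model.state_dict(strict=False)

    # A static torch.nn.Module copy of the sharded model, e.g. for numeric
    # checks or export; by default only rank 0 holds the converted weights.
    torch_model = get_static_torch_model(model, device=torch.device("cpu"))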