From 1151e73ea9396e824208fd97fcd0a198f4fcadcd Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 18 Apr 2023 16:07:50 +0800 Subject: [PATCH 1/4] [gemini] save state dict support fp16 --- colossalai/zero/gemini/gemini_ddp.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 9a193310bab1..0d3ef46fd019 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -202,7 +202,12 @@ def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: for tensor in chunk.get_tensors(): self.grads_device[tensor] = device - def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): + def state_dict(self, + destination=None, + prefix='', + keep_vars=False, + only_rank_0: bool = True, + dtype: torch.dtype = torch.float16): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. @@ -221,7 +226,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: destination = OrderedDict() destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) - self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0) + self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype) for hook in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) @@ -273,7 +278,7 @@ def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_ran param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0)) return param_to_save_data - def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): + def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16): r"""Saves module state to `destination` dictionary, containing a state of the module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.state_dict`. @@ -288,15 +293,19 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): """ assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." + params_to_save = self.fp16_params if dtype == torch.float16 else self.fp32_params # get copies of fp32 parameters in CPU - param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0) + param_to_save_data = self._get_param_to_save_data(params_to_save, only_rank_0) # get the mapping between copies and fp16 parameters p_mapping = dict() for p, fp32_p in zip(self.fp16_params, self.fp32_params): - name = self.param2name[p] - assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) - record_parameter = param_to_save_data[fp32_p] - p_mapping[p] = record_parameter + if dtype == torch.float16: + p_mapping[p] = param_to_save_data[p] + else: + name = self.param2name[p] + assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) + record_parameter = param_to_save_data[fp32_p] + p_mapping[p] = record_parameter for name, param in self.name2param.items(): if param is not None: if is_ddp_ignored(param): From 96b086a120ff12917d6590d774191a352924e43a Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 18 Apr 2023 16:10:39 +0800 Subject: [PATCH 2/4] [gemini] save state dict shard support fp16 --- colossalai/zero/gemini/gemini_ddp.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 0d3ef46fd019..645dcb66e173 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -583,7 +583,8 @@ def state_dict_shard(self, prefix: str = '', keep_vars: bool = False, max_shard_size: int = 1024, - only_rank_0: bool = True) -> Iterator[OrderedDict]: + only_rank_0: bool = True, + dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. Both parameters and persistent buffers (e.g. running averages) are included. @@ -616,11 +617,11 @@ def state_dict_shard(self, # deal with ddp ignored parameters gathered_param = param if keep_vars else param.detach() else: - fp32_param = fp16_to_fp32[param] - if fp32_param not in gathered_param_buffer: - chunk = self.chunk_manager.get_chunk(fp32_param) + param_to_save = param if dtype == torch.float16 else fp16_to_fp32[param] + if param_to_save not in gathered_param_buffer: + chunk = self.chunk_manager.get_chunk(param_to_save) gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0)) - gathered_param = gathered_param_buffer.pop(fp32_param) + gathered_param = gathered_param_buffer.pop(param_to_save) block = sharder.append(prefix + name, gathered_param) if block is not None: From 224b7240ccb65c8436d4c70ac2811a60fbf3dfcc Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 18 Apr 2023 16:29:31 +0800 Subject: [PATCH 3/4] [gemini] fix state dict --- colossalai/zero/gemini/gemini_ddp.py | 35 +++++++++---------- .../test_gemini/test_zeroddp_state_dict.py | 2 +- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 645dcb66e173..41382afd1dcd 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -234,7 +234,7 @@ def state_dict(self, destination = hook_result return destination - def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: + def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict: """ get gathered chunk content. @@ -247,7 +247,7 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: """ # save parameters chunk_to_save_data = dict() - temp_chunk = get_temp_total_chunk_on_cuda(chunk) + temp_chunk = get_temp_total_chunk_on_cuda(chunk).to(dtype) for tensor, tensor_info in chunk.tensors_info.items(): record_tensor = torch.empty([0]) record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) @@ -260,7 +260,8 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: del temp_chunk return chunk_to_save_data - def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict: + def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool, + dtype: torch.dtype) -> Dict: """ get param content from chunks. @@ -275,7 +276,7 @@ def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_ran param_to_save_data = dict() chunk_list = self.chunk_manager.get_chunks(param_list) for chunk in chunk_list: - param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0)) + param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) return param_to_save_data def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16): @@ -293,19 +294,16 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, """ assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." - params_to_save = self.fp16_params if dtype == torch.float16 else self.fp32_params # get copies of fp32 parameters in CPU - param_to_save_data = self._get_param_to_save_data(params_to_save, only_rank_0) + # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16 + param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype) # get the mapping between copies and fp16 parameters p_mapping = dict() for p, fp32_p in zip(self.fp16_params, self.fp32_params): - if dtype == torch.float16: - p_mapping[p] = param_to_save_data[p] - else: - name = self.param2name[p] - assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) - record_parameter = param_to_save_data[fp32_p] - p_mapping[p] = record_parameter + name = self.param2name[p] + assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) + record_parameter = param_to_save_data[fp32_p] + p_mapping[p] = record_parameter for name, param in self.name2param.items(): if param is not None: if is_ddp_ignored(param): @@ -617,11 +615,12 @@ def state_dict_shard(self, # deal with ddp ignored parameters gathered_param = param if keep_vars else param.detach() else: - param_to_save = param if dtype == torch.float16 else fp16_to_fp32[param] - if param_to_save not in gathered_param_buffer: - chunk = self.chunk_manager.get_chunk(param_to_save) - gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0)) - gathered_param = gathered_param_buffer.pop(param_to_save) + # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16 + fp32_param = fp16_to_fp32[param] + if fp32_param not in gathered_param_buffer: + chunk = self.chunk_manager.get_chunk(fp32_param) + gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) + gathered_param = gathered_param_buffer.pop(fp32_param) block = sharder.append(prefix + name, gathered_param) if block is not None: diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 66e05f3ed1ec..cb801f0553f8 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -81,7 +81,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): torch_dict = torch_model.state_dict() model.load_state_dict(torch_dict, strict=False) - zero_dict = model.state_dict(only_rank_0=False) + zero_dict = model.state_dict(only_rank_0=False, dtype=torch.float32) for key, value in torch_dict.items(): assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) From 9053980c4cbf47364c04c68cb65860b1de68be4b Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 18 Apr 2023 16:33:50 +0800 Subject: [PATCH 4/4] [gemini] fix state dict --- colossalai/zero/gemini/gemini_ddp.py | 4 +++- tests/test_zero/test_gemini/test_zeroddp_state_dict.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 41382afd1dcd..e151f1aefb2d 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -247,7 +247,9 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch. """ # save parameters chunk_to_save_data = dict() - temp_chunk = get_temp_total_chunk_on_cuda(chunk).to(dtype) + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + if torch.is_floating_point(temp_chunk): + temp_chunk = temp_chunk.to(dtype) for tensor, tensor_info in chunk.tensors_info.items(): record_tensor = torch.empty([0]) record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index cb801f0553f8..66e05f3ed1ec 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -81,7 +81,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str): torch_dict = torch_model.state_dict() model.load_state_dict(torch_dict, strict=False) - zero_dict = model.state_dict(only_rank_0=False, dtype=torch.float32) + zero_dict = model.state_dict(only_rank_0=False) for key, value in torch_dict.items(): assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)