colossalai/zero/gemini/gemini_ddp.py — 29 changes: 20 additions & 9 deletions
@@ -202,7 +202,12 @@ def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
         for tensor in chunk.get_tensors():
             self.grads_device[tensor] = device

-    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
+    def state_dict(self,
+                   destination=None,
+                   prefix='',
+                   keep_vars=False,
+                   only_rank_0: bool = True,
+                   dtype: torch.dtype = torch.float16):
         """Returns a dictionary containing a whole state of the module.

         Both parameters and persistent buffers (e.g. running averages) are included.
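For orientation, here is a minimal caller-side sketch of the new `dtype` argument introduced above. Only the `state_dict(..., dtype=...)` signature comes from this diff; the helper name, the assumption that `model` is a GeminiDDP-wrapped module, and the checkpoint path are illustrative.

import torch

def save_half_precision_checkpoint(model, path: str) -> None:
    # `model` is assumed to be a GeminiDDP instance (hypothetical setup);
    # only the dtype keyword mirrors the signature added in this diff.
    state_dict = model.state_dict(only_rank_0=True, dtype=torch.float16)
    if torch.distributed.get_rank() == 0:
        torch.save(state_dict, path)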
@@ -221,15 +226,15 @@ def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0:
             destination = OrderedDict()
             destination._metadata = OrderedDict()
         destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
-        self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
+        self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype)

         for hook in self._state_dict_hooks.values():
             hook_result = hook(self, destination, prefix, local_metadata)
             if hook_result is not None:
                 destination = hook_result
         return destination

-    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
+    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict:
         """
         get gathered chunk content.
@@ -243,6 +248,8 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
         # save parameters
         chunk_to_save_data = dict()
         temp_chunk = get_temp_total_chunk_on_cuda(chunk)
+        if torch.is_floating_point(temp_chunk):
+            temp_chunk = temp_chunk.to(dtype)
         for tensor, tensor_info in chunk.tensors_info.items():
             record_tensor = torch.empty([0])
             record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
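The floating-point guard added above only casts float tensors to the requested checkpoint dtype; non-float tensors (e.g. integer buffers) pass through unchanged. A small standalone sketch of that behaviour, independent of ColossalAI:

import torch

def cast_for_checkpoint(t: torch.Tensor, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # Mirrors the guard above: cast only floating-point tensors.
    if torch.is_floating_point(t):
        return t.to(dtype)
    return t

print(cast_for_checkpoint(torch.zeros(2, dtype=torch.float32)).dtype)  # torch.float16
print(cast_for_checkpoint(torch.zeros(2, dtype=torch.int64)).dtype)    # torch.int64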
@@ -255,7 +262,8 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
         del temp_chunk
         return chunk_to_save_data

-    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
+    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool,
+                                dtype: torch.dtype) -> Dict:
         """
         get param content from chunks.
@@ -270,10 +278,10 @@ def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_ran
         param_to_save_data = dict()
         chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
-            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
         return param_to_save_data

-    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
+    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16):
         r"""Saves module state to `destination` dictionary, containing a state
         of the module, but not its descendants. This is called on every
         submodule in :meth:`~torch.nn.Module.state_dict`.
@@ -289,7 +297,8 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."

         # get copies of fp32 parameters in CPU
-        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
+        # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
+        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype)
         # get the mapping between copies and fp16 parameters
         p_mapping = dict()
         for p, fp32_p in zip(self.fp16_params, self.fp32_params):
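The new comment documents the design choice: the working fp16 parameter storage may be reused for gradients, so checkpoints are built from the fp32 master copies and cast down on the way out. A generic illustration of that master-copy pattern (not GeminiDDP code; the helper name is hypothetical):

import torch

def param_for_checkpoint(fp32_master: torch.Tensor, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # The fp32 master copy is the reliable source: detach, move to CPU,
    # then cast to the requested checkpoint dtype.
    return fp32_master.detach().cpu().to(dtype)

master = torch.randn(4, dtype=torch.float32)
print(param_for_checkpoint(master).dtype)  # torch.float16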
@@ -574,7 +583,8 @@ def state_dict_shard(self,
                          prefix: str = '',
                          keep_vars: bool = False,
                          max_shard_size: int = 1024,
-                         only_rank_0: bool = True) -> Iterator[OrderedDict]:
+                         only_rank_0: bool = True,
+                         dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.

         Both parameters and persistent buffers (e.g. running averages) are included.
@@ -607,10 +617,11 @@ def state_dict_shard(self,
                     # deal with ddp ignored parameters
                     gathered_param = param if keep_vars else param.detach()
                 else:
+                    # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
                     fp32_param = fp16_to_fp32[param]
                     if fp32_param not in gathered_param_buffer:
                         chunk = self.chunk_manager.get_chunk(fp32_param)
-                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
                     gathered_param = gathered_param_buffer.pop(fp32_param)

                 block = sharder.append(prefix + name, gathered_param)
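A caller-side sketch of the sharded variant, assuming the iterator yields state-dict shards as the `Iterator[OrderedDict]` annotation above suggests; the helper name, shard handling, and file naming are illustrative assumptions, not part of this diff.

import torch

def save_sharded_checkpoint(model, prefix: str = "checkpoint_shard") -> None:
    # `model` is assumed to be a GeminiDDP instance; max_shard_size and dtype
    # mirror the signature shown above.
    for i, shard in enumerate(model.state_dict_shard(max_shard_size=1024,
                                                     only_rank_0=True,
                                                     dtype=torch.float16)):
        if torch.distributed.get_rank() == 0:
            torch.save(shard, f"{prefix}_{i:05d}.pt")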