37 changes: 31 additions & 6 deletions colossalai/nn/parallel/data_parallel.py
@@ -18,6 +18,7 @@
from colossalai.zero.utils.gemini_hook import GeminiZeROHook

from .reducer import Reducer
from .utils import get_static_torch_model

try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@@ -251,6 +252,7 @@ def __init__(self,
pin_memory=pin_memory)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device

self.chunk_manager.close_all_groups()
self._cast_buffers()

@@ -331,12 +333,11 @@ def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device

def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
r"""Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
r"""
Args:
strict (bool): whether to return the whole model state
in the same form as the original PyTorch ``state_dict()``

Returns:
dict:
@@ -346,7 +347,31 @@ def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0:

>>> module.state_dict().keys()
['bias', 'weight']
"""
if strict:
return get_static_torch_model(zero_ddp_model=self, device=get_current_device(),
only_rank_0=only_rank_0).state_dict(destination=destination,
prefix=prefix,
keep_vars=keep_vars)
return self._non_strict_state_dict(destination=destination,
prefix=prefix,
keep_vars=keep_vars,
only_rank_0=only_rank_0)

def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
r"""Returns a dictionary containing a whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.

Warning: The non-strict state dict ignores a parameter if its tensor is
shared with another parameter that has already been included in the
dictionary.

Returns:
dict:
a dictionary containing a whole state of the module
"""
if destination is None:
destination = OrderedDict()
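The dispatch above routes ``strict=True`` through ``get_static_torch_model`` and ``strict=False`` through the previous per-chunk path. A minimal usage sketch of the new flag, assuming ``model`` is an already-constructed ZeroDDP instance; the model itself and the checkpoint path are placeholders, not part of this PR:

    import torch

    # Assumes `model` is an existing ZeroDDP-wrapped module (hypothetical setup).
    # strict=True (the default) rebuilds a plain torch-style state dict via
    # get_static_torch_model, so the keys match the original nn.Module.
    full_sd = model.state_dict(strict=True)
    torch.save(full_sd, "checkpoint.pt")

    # strict=False keeps the previous behaviour (_non_strict_state_dict), which may
    # skip parameters whose tensors are shared with ones already in the dict.
    loose_sd = model.state_dict(strict=False, only_rank_0=True)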
16 changes: 8 additions & 8 deletions colossalai/nn/parallel/utils.py
@@ -60,29 +60,29 @@ def _get_shallow_copy_model(model: nn.Module):
return name_to_module['']


def get_static_torch_model(gemini_ddp_model,
def get_static_torch_model(zero_ddp_model,
device=torch.device("cpu"),
dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module:
"""Get a static torch.nn.Module model from the given GeminiDDP module.
You should notice that the original GeminiDDP model is not modified.
"""Get a static torch.nn.Module model from the given ZeroDDP module.
Note that the original ZeroDDP model is not modified, so you can keep using it for further training.
However, you should not train the returned torch model, as this can cause unexpected errors.

Args:
gemini_ddp_model (GeminiDDP): a gemini ddp model
zero_ddp_model (ZeroDDP): a zero ddp model
device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model
only_rank_0 (bool): if True, only rank 0 has the converted torch model

Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
"""
from colossalai.nn.parallel import GeminiDDP
assert isinstance(gemini_ddp_model, GeminiDDP)
from colossalai.nn.parallel import ZeroDDP
assert isinstance(zero_ddp_model, ZeroDDP)

state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
colo_model = gemini_ddp_model.module
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
colo_model = zero_ddp_model.module
torch_model = _get_shallow_copy_model(colo_model)

if not only_rank_0 or dist.get_rank() == 0:
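A short sketch of how the renamed ``get_static_torch_model`` helper might be called to extract a plain torch model from a ZeroDDP instance; ``zero_model`` and the output path are assumptions for illustration, not part of this PR:

    import torch
    import torch.distributed as dist
    from colossalai.nn.parallel.utils import get_static_torch_model

    # `zero_model` is assumed to be an existing ZeroDDP instance (hypothetical).
    torch_model = get_static_torch_model(zero_ddp_model=zero_model,
                                         device=torch.device("cpu"),
                                         dtype=torch.float32,
                                         only_rank_0=True)

    # With only_rank_0=True, only rank 0 holds the converted weights,
    # so guard the save accordingly.
    if dist.get_rank() == 0:
        torch.save(torch_model.state_dict(), "static_model.pt")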