Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 30 additions & 32 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import logging
import os
from collections import defaultdict
from functools import reduce
from pathlib import Path
from shutil import rmtree
Expand All @@ -10,6 +11,7 @@
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map

Expand Down Expand Up @@ -37,7 +39,6 @@
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
Expand Down Expand Up @@ -724,26 +725,37 @@ def _get_param_id_from_optimizer_param(
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
self.load_states_into_optimizer(optimizer, state_dict, id_map)
loaded_file.add(filename)

# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state

sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict):
    """Shard checkpointed optimizer states and install them on ``optimizer.optim``.

    Only entries of ``state_dict`` whose (integer-cast) key is present in
    ``id_map`` are processed; every other entry is silently skipped.

    Args:
        optimizer: Object exposing ``optim.state``, ``param_info["param2shape"]``
            and ``get_master_to_working_map()``.
            NOTE(review): the annotation says ``Optimizer``, but these attributes
            look like those of an optimizer wrapper -- confirm against the callers.
        state_dict: Mapping from parameter id (anything ``int()`` accepts) to the
            complete, unsharded state of that parameter.
        id_map: Mapping from integer parameter id to the parameter object that is
            used as the key in ``optimizer.optim.state``.
    """
    # Normalise every key to int up front so the membership test against
    # ``id_map`` below compares like with like.
    states_by_id = {int(param_id): param_state for param_id, param_state in state_dict.items()}
    sharded_states = {}
    master_to_working = optimizer.get_master_to_working_map()

    for param_id, full_state in states_by_id.items():
        if param_id not in id_map:
            continue

        target_param = id_map[param_id]
        target_device = target_param.device
        target_dtype = target_param.dtype

        # Shapes are looked up through the working parameter when a
        # master -> working mapping is available; otherwise the parameter
        # from ``id_map`` is used directly.
        working_param = target_param if master_to_working is None else master_to_working[id(target_param)]
        full_shape = optimizer.param_info["param2shape"][id(working_param)]

        sharded_states[target_param] = self.shard_from_complete_optimizer_state(
            full_state,
            current_shape=working_param.shape,
            original_shape=full_shape,
            device=target_device,
            dtype=target_dtype,
            inplace=True,
        )

    # Install everything in one go once all selected states have been sharded.
    optimizer.optim.state.update(sharded_states)

def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
Expand Down Expand Up @@ -988,22 +1000,7 @@ def _get_param_id_from_optimizer_param(
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)

# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
if param is None:
continue
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state
self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)

sharded_optimizer_loading_epilogue(optimizer.optim)

Expand Down Expand Up @@ -1086,6 +1083,7 @@ def shard_from_complete_optimizer_state(
current_shape: torch.Size,
original_shape: torch.Size,
device: torch.device,
dtype: torch.dtype,
inplace: bool,
) -> OrderedDict:
"""
Expand Down Expand Up @@ -1135,7 +1133,7 @@ def shard_from_complete_optimizer_state(
slice_size = v.numel() // self.global_dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]

state_[k] = v.detach().clone().to(device)
state_[k] = v.detach().clone().to(device=device, dtype=dtype)

return state_

Expand Down