180 changes: 90 additions & 90 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -534,96 +534,96 @@ def save_sharded_optimizer(
f"index located at {final_index_file_path}."
)

# def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
# """
# Load sharded optimizer with the given path to index file of checkpoint folder.

# Args:
# optimizer (OptimizerWrapper): The optimizer to be loaded.
# checkpoint_index_file (str): Path to the index file of checkpointing folder.
# prefix (str): Not used.
# """
# assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"

# def _get_param_id_from_optimizer_param(
# param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
# ):
# if master_to_working_map is not None:
# working_param = master_to_working_map[id(param)]
# else:
# working_param = param
# return optimizer.param_info["param2id"][id(working_param)]

# # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
# # When Zero is used, the mapped parameter objects should be fp32 master parameters.
# # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
# id_map = {}
# master_to_working_map = optimizer.get_master_to_working_map()
# for pg in optimizer.optim.param_groups:
# for param in pg["params"]:
# param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
# id_map[param_id] = param

# # Read checkpoint index file.
# ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
# ckpt_root_path = ckpt_index_file.root_path
# weight_map = ckpt_index_file.weight_map
# weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int

# # Load param_groups
# param_group_path = ckpt_index_file.get_param_group_filename()
# if param_group_path is None:
# raise RuntimeError(
# f"Invalid index file path {checkpoint_index_file} for an optimizer. \
# Lacking param group file under current directory."
# )
# saved_groups = torch.load(param_group_path)

# updated_groups = []
# for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# # obtain updated param group
# new_pg = copy.deepcopy(saved_pg)
# new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
# updated_groups.append(new_pg)
# optimizer.optim.__dict__.update({"param_groups": updated_groups})

# # Load saved states to optimizer.
# # Keep a record of loaded files so that file will not be repeatedly loaded.
# loaded_file = set()
# for pg in optimizer.optim.param_groups:
# for param in pg["params"]:
# if param is None:
# continue
# param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
# if param_id not in weight_map:
# continue
# filename = weight_map[param_id]

# # If this param's states has been loaded before, directly return.
# if filename in loaded_file:
# continue

# file_path = os.path.join(ckpt_root_path, filename)
# state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
# load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
# loaded_file.add(filename)

# # Then shard the loaded optimizer states if using tp/zero.
# for param, state in optimizer.optim.state.items():
# device = param.device
# if master_to_working_map is not None:
# working_param = master_to_working_map[id(param)]
# else:
# working_param = param
# original_shape = optimizer.param_info["param2shape"][id(working_param)]
# sharded_state = self.shard_from_complete_optimizer_state(
# state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
# )
# optimizer.optim.state[param] = sharded_state

# sharded_optimizer_loading_epilogue(optimizer.optim)
# if self.verbose and self.coordinator.is_master():
# logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.

Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"

def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info["param2id"][id(working_param)]

# id_map maps each param id kept by the current pipeline stage to its corresponding parameter object.
# When Zero is used, the mapped parameter objects should be the fp32 master parameters.
# IDs are obtained through the param2id mapping saved earlier in optimizer.param_info.
id_map = {}
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param

# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int

# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(
f"Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory."
)
saved_groups = torch.load(param_group_path)

updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})

# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id not in weight_map:
continue
filename = weight_map[param_id]

# If the file holding this param's states has already been loaded, skip it.
if filename in loaded_file:
continue

file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)

# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(
state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
)
optimizer.optim.state[param] = sharded_state

sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
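For context, the `load_sharded_optimizer` method re-enabled above is normally reached through the Booster checkpoint API rather than called directly. Below is a minimal sketch of that flow under stated assumptions: a torchrun launch, a HybridParallelPlugin-boosted model and optimizer, and an illustrative checkpoint path. The plugin arguments are placeholders and are not taken from this PR.

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Older ColossalAI versions require a (possibly empty) config dict; newer ones accept no argument.
colossalai.launch_from_torch(config={})

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# tp/pp/zero sizes here are placeholders; any HybridParallelPlugin configuration goes
# through the same HybridParallelCheckpointIO code paths touched in this diff.
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, zero_stage=1, precision="fp16")
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

# shard=True writes per-parameter state files plus an index file; loading that
# checkpoint dispatches to the load_sharded_optimizer method restored above.
booster.save_optimizer(optimizer, "ckpt/optim", shard=True)
booster.load_optimizer(optimizer, "ckpt/optim")
```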
4 changes: 2 additions & 2 deletions colossalai/shardformer/modeling/mixtral.py
@@ -13,7 +13,7 @@
MoeCausalLMOutputWithPast,
load_balancing_loss_func,
)
from transformers.utils import logging
from transformers.utils import is_flash_attn_2_available, logging

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
@@ -218,7 +218,7 @@ def mixtral_model_forward(

# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if self._use_flash_attention_2:
if is_flash_attn_2_available():
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
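The hunk above swaps the per-instance `self._use_flash_attention_2` flag for transformers' module-level availability check. A small sketch of the resulting mask handling follows, assuming a 2-D padding mask as in the modeling code; `is_flash_attn_2_available` is the real transformers helper, but the wrapper function name is made up for illustration.

```python
from typing import Optional

import torch
from transformers.utils import is_flash_attn_2_available

def prepare_padding_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Flash-attention 2 consumes the raw 2-D padding mask and only needs it when some
    # positions are actually padded (mask contains a 0); otherwise the mask is dropped.
    if is_flash_attn_2_available():
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    # Non-flash paths would instead expand this into a 4-D causal mask (not shown here).
    return attention_mask
```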
4 changes: 3 additions & 1 deletion tests/test_shardformer/test_model/test_shard_llama.py
@@ -65,7 +65,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
working_p = sharded_optimizer.master_to_working_param[id(p2)]
grads = sharded_optimizer.get_partitioned_gradients_by_param_id(0, id(working_p))
grad_index = (
0 if sharded_optimizer._partition_grads else sharded_optimizer.pid_to_bucket_store[id(p2)].local_rank
0
if sharded_optimizer._partition_grads
else sharded_optimizer.pid_to_bucket_store[id(working_p)].local_rank
)
grad = grads[grad_index]
sharded_grad = p1.grad.view(-1).chunk(dist.get_world_size())[dist.get_rank()]
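For reference, the reason the last hunk switches from `id(p2)` to `id(working_p)`: with a Zero-wrapped optimizer, `p2` is the fp32 master copy, while bookkeeping such as `pid_to_bucket_store` is keyed by the id of the working (model) parameter, so the master parameter must be mapped back first. The toy sketch below illustrates that keying; it is not ColossalAI internals, and the dict names only mirror the attributes used in the test.

```python
import torch

# The working (fp16) parameter lives in the model; the master (fp32) copy lives in the
# optimizer; per-parameter bookkeeping is keyed by the *working* tensor's id.
working_p = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
master_p = working_p.detach().clone().float()

master_to_working_param = {id(master_p): working_p}
pid_to_bucket_store = {id(working_p): "bucket-for-this-param"}  # keyed by working id

p2 = master_p
assert id(p2) not in pid_to_bucket_store           # why the old lookup via id(p2) failed
working = master_to_working_param[id(p2)]
assert id(working) in pid_to_bucket_store          # the fixed lookup via id(working_p) succeeds
```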