diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 46714fe1c679..4a7efc165cbd 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -99,8 +99,11 @@ def save_sharded_model(self,
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
 
         index_file.append_meta_data("total_size", total_size)
-        index_file.write_index_file(save_index_file)
-        logging.info(f"The model is going to be split to checkpoint shards. "
+
+        # only save the index file on the master rank
+        if self.coordinator.is_master():
+            index_file.write_index_file(save_index_file)
+        logging.info(f"The model is split into checkpoint shards. "
                      f"You can find where each parameters has been saved in the "
                      f"index located at {save_index_file}.")
 
diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py
index 334ecbc04738..a41cc482e054 100644
--- a/colossalai/checkpoint_io/index_file.py
+++ b/colossalai/checkpoint_io/index_file.py
@@ -1,8 +1,8 @@
 import json
-from pathlib import Path
-from typing import Any, List, Union
 import os
-import json
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Dict, List, Union
 
 from .utils import is_dtensor_checkpoint
 
@@ -22,8 +22,10 @@ class CheckpointIndexFile:
 
     def __init__(self, root_path=None) -> None:
         self.root_path = root_path
-        self.metadata: dict = dict()
-        self.weight_map: dict = dict()
+
+        # use ordered dict to preserve the tensor checkpoint order
+        self.metadata: Dict = OrderedDict()
+        self.weight_map: Dict = OrderedDict()
 
     @staticmethod
     def from_file(index_path: Union[str, Path]):
@@ -150,13 +152,13 @@ def get_checkpoint_file(self, param_name: str) -> str:
         """
         ckpt_path = self.weight_map[param_name]
         return ckpt_path
-    
+
     def get_all_param_names(self):
         """
         Get all the weight keys.
         """
         return list(self.weight_map.keys())
-    
+
     def write_index_file(self, save_index_file):
         """
         Write index file.
@@ -164,5 +166,5 @@ def write_index_file(self, save_index_file):
         save_index_file = os.path.join(self.root_path, save_index_file)
         index = {"metadata": self.metadata, "weight_map": self.weight_map}
         with open(save_index_file, "w", encoding="utf-8") as f:
-            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            content = json.dumps(index, indent=2) + "\n"
             f.write(content)
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 7e23fdb425f8..094320c4aff4 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -716,7 +716,10 @@ def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict]
         tensor_size = calculate_tensor_size(tensor)
         ret_block = None
         ret_block_size = 0
-        if self.current_block_size + tensor_size > self.max_shard_size:
+
+        # before we return the current block and create a new block,
+        # we need to ensure that the current block is not empty
+        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
             ret_block = self.current_block
             ret_block_size = self.current_block_size
             self.current_block = OrderedDict()
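
For context, a minimal self-contained sketch of the sharding logic patched in gemini_ddp.py follows. SimpleSharder and its fields are hypothetical stand-ins, not the actual Gemini class; the sketch only shows why the added current_block_size > 0 guard matters: without it, a single tensor larger than max_shard_size would flush an empty shard before ever being stored.

# Minimal sketch of the patched sharder logic (hypothetical SimpleSharder,
# not the actual ColossalAI implementation).
from collections import OrderedDict
from typing import Optional, Tuple

import torch


def calculate_tensor_size(tensor: torch.Tensor) -> int:
    # tensor footprint in bytes
    return tensor.numel() * tensor.element_size()


class SimpleSharder:
    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.current_block: OrderedDict = OrderedDict()
        self.current_block_size = 0

    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
        tensor_size = calculate_tensor_size(tensor)
        ret_block = None
        ret_block_size = 0
        # flush the current block only if it actually holds tensors;
        # otherwise a tensor larger than max_shard_size would emit an
        # empty shard before being stored itself
        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
            ret_block = self.current_block
            ret_block_size = self.current_block_size
            self.current_block = OrderedDict()
            self.current_block_size = 0
        self.current_block[name] = tensor
        self.current_block_size += tensor_size
        return ret_block, ret_block_size


# usage: an oversized tensor starts its own block instead of forcing
# out an empty one
sharder = SimpleSharder(max_shard_size=16)
block, _ = sharder.append("big.weight", torch.zeros(8))  # 32 bytes > 16
assert block is None  # with the old condition this would be an empty dict
block, size = sharder.append("small.weight", torch.zeros(1))
assert list(block.keys()) == ["big.weight"] and size == 32

The index_file.py changes work together: the OrderedDict fields plus dropping sort_keys keep the weight_map entries in the order tensors were appended to the shards. A quick illustration, with hypothetical tensor names and shard filenames:

# json.dumps preserves dict insertion order when sort_keys is omitted,
# so the written index mirrors the shard append order
import json
from collections import OrderedDict

weight_map = OrderedDict()
weight_map["transformer.wte.weight"] = "shard-00001.bin"  # saved first
weight_map["lm_head.weight"] = "shard-00002.bin"          # saved last

print(json.dumps({"weight_map": weight_map}, indent=2))
# keys appear in append order; sort_keys=True would alphabetize them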