From cd30f4bfeda3e8ec01e73043e24f581734385348 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 14 Nov 2024 08:04:20 +0000 Subject: [PATCH 01/12] fix --- .github/workflows/build_on_pr.yml | 2 +- colossalai/booster/booster.py | 3 +- colossalai/booster/plugin/gemini_plugin.py | 57 +++++++++----- .../booster/plugin/torch_fsdp_plugin.py | 49 +++++++----- .../checkpoint_io/checkpoint_io_base.py | 4 +- .../checkpoint_io/general_checkpoint_io.py | 1 - .../hybrid_parallel_checkpoint_io.py | 77 +++++++++++++------ colossalai/checkpoint_io/moe_checkpoint.py | 52 ++++++++----- colossalai/checkpoint_io/utils.py | 11 ++- 9 files changed, 170 insertions(+), 86 deletions(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index ceb33c9ac7a8..8d96ca1b90bc 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -117,7 +117,7 @@ jobs: cd TensorNVMe conda install cmake pip install -r requirements.txt - DISABLE_URING=1 pip install -v . + DISABLE_URING=1 pip install -v --no-cache-dir . - name: Store TensorNVMe Cache run: | diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 2518b25111a4..bfd5300536eb 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -310,7 +310,7 @@ def save_model( prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, - use_async: bool = False, + use_async: Optional[bool] = False, ) -> None: """Save model to checkpoint. @@ -325,6 +325,7 @@ def save_model( names to compose the keys in state_dict. Defaults to None. size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved. + use_async (bool, optional): whether to save the state_dict of model asynchronously. Default: False. """ self.checkpoint_io.save_model( model, diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 4c8258113018..6e6f30c30b5a 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -65,7 +65,14 @@ def __init__(self) -> None: self.coordinator = DistCoordinator() self.logger = get_dist_logger() - def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model( + self, + model: GeminiDDP, + checkpoint: str, + gather_dtensor: bool, + use_safetensors: bool, + use_async: Optional[bool] = False, + ): """ Save sharded model to checkpoint but only on master process. The model should be unwrapped in self.load_model via ModelWrapper.unwrap. @@ -74,7 +81,10 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor assert isinstance(model, GeminiDDP), "Please boost the model before saving!" state_dict = model.state_dict(only_rank_0=True) if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors) + if use_async: + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) + else: + save_state_dict(state_dict, checkpoint, use_safetensors, use_async) def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): """ @@ -112,6 +122,7 @@ def save_sharded_model( prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False, + use_async: Optional[bool] = False, ): """ Save sharded model. @@ -130,27 +141,33 @@ def save_sharded_model( # Save shards of optimizer states. is_master = self.coordinator.is_master() - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=is_master, - use_safetensors=use_safetensors, - ) + if use_async: + super().save_sharded_model( + model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async + ) - # only save the index file on the master rank - if self.coordinator.is_master(): - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model.unwrap(), checkpoint_path) - self.logger.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.", - ranks=[0], + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=is_master, + use_safetensors=use_safetensors, ) + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model.unwrap(), checkpoint_path) + self.logger.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.", + ranks=[0], + ) + def load_sharded_model( self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False ): diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 23a35bbcbd3b..04fa21b01b69 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -82,6 +82,7 @@ def save_sharded_model( prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, + use_async: Optional[bool] = False, ): """ Save model to checkpoint but only on master process. @@ -102,26 +103,36 @@ def save_sharded_model( weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) - # In general cases, is_master is set to True to get the right behavior. - total_size = utils.save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=self.coordinator.is_master(), - use_safetensors=use_safetensors, - ) - - # only save the index file on the master rank - if self.coordinator.is_master(): - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - utils.save_config_file(model.unwrap(), checkpoint_path) - self.logger.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + if use_async: + super().save_sharded_model( + model=model, + checkpoint_path=checkpoint_path, + gather_dtensor=gather_dtensor, + prefix=prefix, + use_safetensors=use_safetensors, + use_async=use_async, ) + else: + # In general cases, is_master is set to True to get the right behavior. + total_size = utils.save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=self.coordinator.is_master(), + use_safetensors=use_safetensors, + ) + + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + utils.save_config_file(model.unwrap(), checkpoint_path) + self.logger.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) def load_sharded_model( self, diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 4d752f3e6e9c..6e4681f0ec2e 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -176,10 +176,10 @@ def save_model( if shard: self.save_sharded_model( - model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async + model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async ) else: - self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) + self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024): """ diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index a4866e64c9e8..ef58e119da90 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -62,7 +62,6 @@ def save_unsharded_model( self.async_writers.append(writer) move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)]) else: - # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 79bb33dca3a4..49d4f35f9cc0 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -27,6 +27,8 @@ from .index_file import CheckpointIndexFile from .utils import ( StateDictSharder, + async_save_state_dict_shards, + create_pinned_state_dict, gather_distributed_param, get_model_base_filenames, get_optimizer_base_filenames, @@ -177,6 +179,7 @@ def save_sharded_model( prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ) -> None: """ Save sharded model checkpoint under the given checkpointing path. @@ -194,6 +197,7 @@ def save_sharded_model( prefix (str, optional): Perfix of file to save. Defaults to None. size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. """ assert isinstance(model, ModelWrapper), "Please boost the model before saving!" @@ -219,24 +223,27 @@ def save_sharded_model( if self.pp_size == 1: # When pipeline is not used, save the model shards as in general checkpointIO - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors, - ) - if control_saving: - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model, checkpoint) - if self.verbose and self.coordinator.is_master(): - logging.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) + if use_async: + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) else: # When pipeline is used, each stage produces its own shard files and index files. @@ -251,7 +258,16 @@ def save_sharded_model( weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) - + if use_async: + total_size, returned_state_dict, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_pp_format=True, + n_write_entries=191, + ) total_size = save_state_dict_shards( sharded_state_dict=state_dict_shard, checkpoint=checkpoint, @@ -626,7 +642,9 @@ def _get_param_id_from_optimizer_param( if self.verbose and self.coordinator.is_master(): logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") - def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model( + self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False + ): """ Save model state dict to a single file with given checkpointing path. @@ -635,6 +653,7 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path. gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True. use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False. """ if self.coordinator.is_master(): logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!") @@ -651,7 +670,10 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten if self.pp_size == 1: # When pipeline is not used, let master rank directly save the collected state_dict. if self.tp_rank == 0: - save_state_dict(state_dict, checkpoint, use_safetensors) + if use_async: + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) + else: + save_state_dict(state_dict, checkpoint, use_safetensors) else: # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. state_dict_list = [None for _ in range(self.pp_size)] @@ -662,7 +684,18 @@ def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dten complete_state_dict = dict() for _state_dict in state_dict_list: complete_state_dict.update(_state_dict) - save_state_dict(complete_state_dict, checkpoint, use_safetensors) + if use_async: + from tensornvme.async_file_io import AsyncFileWriter + + from colossalai.utils.safetensors import move_and_save + + writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + self.async_writers.append(writer) + move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)]) + else: + save_state_dict(complete_state_dict, checkpoint, use_safetensors) def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = False): """ diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py index 9181956b7f60..4cb0f300f65e 100644 --- a/colossalai/checkpoint_io/moe_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -117,6 +117,7 @@ def save_sharded_model( prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ) -> None: """ Save sharded model checkpoint under the given checkpointing path. @@ -161,24 +162,27 @@ def save_sharded_model( if self.pp_size == 1 and self.ep_size == 1: # When pipeline is not used, save the model shards as in general checkpointIO - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors, - ) - if control_saving: - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model, checkpoint) - if self.verbose and self.coordinator.is_master(): - logging.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) + if use_async: + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) dist.barrier() else: @@ -708,10 +712,20 @@ def save_unsharded_model( checkpoint: str, gather_dtensor: bool, use_safetensors: bool, + use_async: bool = False, ): state_dict = self.pre_save_model(model) if dist.get_rank() == 0: - torch.save(state_dict, checkpoint) + if use_async: + super().save_unsharded_model( + model=model, + checkpoint=checkpoint, + gather_dtensor=gather_dtensor, + use_safetensors=use_safetensors, + use_async=use_async, + ) + else: + torch.save(state_dict, checkpoint) dist.barrier() # Copied from colossalai.moe diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 6d539cce60c9..0b65d70dd8cd 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -21,6 +21,11 @@ ) from colossalai.utils.safetensors import move_and_save +try: + pass +except ModuleNotFoundError: + raise ModuleNotFoundError("Please install tensornvme to use Async Checkpoint save") + SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" STATES_NAME = "pytorch_optim.bin" @@ -371,7 +376,11 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> # ====================================== -def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: +def save_state_dict( + state_dict: dict, + checkpoint_file_path: str, + use_safetensors: bool, +) -> None: """ Save state dict to checkpoint. From 8552c869adbde3979084545d2b91d21909ad586d Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 14 Nov 2024 08:16:42 +0000 Subject: [PATCH 02/12] fix --- colossalai/booster/booster.py | 2 +- colossalai/booster/plugin/gemini_plugin.py | 4 ++-- .../booster/plugin/torch_fsdp_plugin.py | 22 ++++++++++++++----- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index bfd5300536eb..ad4047ee2fc5 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -310,7 +310,7 @@ def save_model( prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, - use_async: Optional[bool] = False, + use_async: bool = False, ) -> None: """Save model to checkpoint. diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 6e6f30c30b5a..5ac2e857e809 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -71,7 +71,7 @@ def save_unsharded_model( checkpoint: str, gather_dtensor: bool, use_safetensors: bool, - use_async: Optional[bool] = False, + use_async: bool = False, ): """ Save sharded model to checkpoint but only on master process. @@ -122,7 +122,7 @@ def save_sharded_model( prefix: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False, - use_async: Optional[bool] = False, + use_async: bool = False, ): """ Save sharded model. diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 04fa21b01b69..720a674beae7 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -54,16 +54,26 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model) optimizer.load_state_dict(sharded_osd) - def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + def save_unsharded_model( + self, + model: ModelWrapper, + checkpoint: str, + gather_dtensor: bool, + use_safetensors: bool, + use_async: bool = False, + ): """ Save model to checkpoint but only on master process. """ assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" model = model.unwrap() - cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): - full_model_state = model.state_dict() - utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) + if use_async: + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) + else: + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + full_model_state = model.state_dict() + utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ @@ -82,7 +92,7 @@ def save_sharded_model( prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, - use_async: Optional[bool] = False, + use_async: bool = False, ): """ Save model to checkpoint but only on master process. From 3701c48ef8e180460c8ccef4ed40d8457fbcfe54 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 15 Nov 2024 01:56:47 +0000 Subject: [PATCH 03/12] fix --- colossalai/booster/plugin/gemini_plugin.py | 2 +- .../checkpoint_io/general_checkpoint_io.py | 10 ++-- colossalai/checkpoint_io/utils.py | 14 ++---- colossalai/utils/safetensors.py | 4 +- .../test_gemini_checkpoint_io.py | 24 ++++++++-- .../test_general_checkpoint_io.py | 47 +++++++++++++++++++ .../test_low_level_zero_checkpoint_io.py | 18 +++++-- 7 files changed, 92 insertions(+), 27 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 5ac2e857e809..35c51da0105a 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -84,7 +84,7 @@ def save_unsharded_model( if use_async: super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) else: - save_state_dict(state_dict, checkpoint, use_safetensors, use_async) + save_state_dict(state_dict, checkpoint, use_safetensors) def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): """ diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index ef58e119da90..ac4893380a7b 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -14,7 +14,6 @@ from .index_file import CheckpointIndexFile from .utils import ( async_save_state_dict_shards, - create_pinned_state_dict, get_model_base_filenames, get_optimizer_base_filenames, is_safetensors_available, @@ -57,10 +56,13 @@ def save_unsharded_model( from tensornvme.async_file_io import AsyncFileWriter writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + # if id(model) not in self.pinned_state_dicts: + # self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + writer.sync_before_step() self.async_writers.append(writer) - move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)]) + move_and_save(writer, state_dict, None) + writer.synchronize() + else: # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 0b65d70dd8cd..9bc205705079 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -21,11 +21,6 @@ ) from colossalai.utils.safetensors import move_and_save -try: - pass -except ModuleNotFoundError: - raise ModuleNotFoundError("Please install tensornvme to use Async Checkpoint save") - SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" STATES_NAME = "pytorch_optim.bin" @@ -593,11 +588,8 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): from safetensors.torch import safe_open with safe_open(checkpoint_file, framework="pt") as f: - metadata = f.metadata() - if metadata["format"] != "pt": - raise NotImplementedError( - f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." - ) + f.metadata() + return safe_load_file(checkpoint_file) else: return torch.load(checkpoint_file, map_location=torch.device("cpu")) @@ -819,6 +811,7 @@ def load_state_dict(checkpoint_file_path: Path): from safetensors import safe_open state_dict = {} + print("checkpoint_file_path:", checkpoint_file_path) with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) @@ -826,6 +819,7 @@ def load_state_dict(checkpoint_file_path: Path): else: # load with torch + print("checkpoint_file_path:", checkpoint_file_path) return torch.load(checkpoint_file_path, map_location=torch.device("cpu")) diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index bf8decd0faff..0359541147f0 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -28,14 +28,12 @@ class PreparedData: def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]: - sorted_data = sorted(data.items(), key=lambda x: (x[1].dtype, x[0])) tensors = [] tensor_keys = [] metadata = {} offset = 0 - - for name, tensor in sorted_data: + for name, tensor in data.items(): n = tensor.numel() * tensor.element_size() tensor_info = TensorInfo( dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index b133be948c1e..c810244a9387 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -33,9 +33,12 @@ @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_bert_for_sequence_classification"]) @parameterize("use_safetensors", [False, True]) +@parameterize("use_async", [False, True]) @parameterize("tp_size", [1, 2]) @parameterize("zero_size", [2]) -def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int): +def exam_state_dict_with_origin( + placement_config, model_name, use_safetensors: bool, use_async: bool, tp_size: int, zero_size: int +): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -63,11 +66,19 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 booster.save_model( - bert_model, pretrained_path, True, True, "", (model_size / 3), use_safetensors=use_safetensors + bert_model, + pretrained_path, + True, + True, + "", + (model_size / 3), + use_safetensors=use_safetensors, + use_async=use_async, ) dist.barrier() - + print("pretrained_path: ", pretrained_path) new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) + print("new_bert_model: ", new_bert_model) check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict()) @@ -119,7 +130,12 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" - booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_model( + model, + model_ckpt_path, + shard=shard, + size_per_shard=size_per_shard, + ) booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 8431036df6b7..3d224e144b3d 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -106,6 +106,53 @@ def test_sharded_model_checkpoint(use_safetensors: bool): check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) +@parameterize("use_async", [True, False]) +def test_unsharded_checkpoint(use_async: bool): + # create a model and optimizer + model = resnet18() + optimizer = Adam(model.parameters(), lr=0.001) + lr_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10) + + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + lr_scheduler.step() + + # create a temp file for checkpoint + if use_async: + suffix = ".safetensors" + else: + suffix = ".bin" + model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix) + optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() + lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile() + + # save the model, optimizer, lr_scheduler + ckpt_io = GeneralCheckpointIO() + ckpt_io.save_model(model, model_ckpt_tempfile.name, use_async=use_async) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name) + ckpt_io.save_lr_scheduler(lr_scheduler, lr_scheduler_ckpt_tempfile.name) + + # create new model + new_model = resnet18() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + new_lr_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10) + + # load the model, optimizer, lr_scheduler + ckpt_io.load_model(new_model, model_ckpt_tempfile.name) + ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) + ckpt_io.load_lr_scheduler(new_lr_scheduler, lr_scheduler_ckpt_tempfile.name) + + # check for model and optimizer state dict recursively + check_state_dict_equal(model.state_dict(), new_model.state_dict()) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) + + def test_sharded_optimizer_checkpoint(): # create a model and optimizer model = resnet18() diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index a8e05a25ad28..c8d0e4251ebd 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -28,7 +28,8 @@ @parameterize("stage", [2]) @parameterize("shard", [True, False]) @parameterize("offload", [False, True]) -def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): +@parameterize("use_async", [False, True]) +def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool): plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) booster = Booster(plugin=plugin) model = resnet18() @@ -41,11 +42,18 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): loss = criterion(output) booster.backward(loss, optimizer) optimizer.step() - with shared_tempdir() as tempdir: - model_ckpt_path = f"{tempdir}/model" - optimizer_ckpt_path = f"{tempdir}/optimizer" + output_dir = "./checkpoints" + import os + + os.makedirs(output_dir, exist_ok=True) + + with open(output_dir, "rb") as f: + model_ckpt_path = f"{f}/model" + optimizer_ckpt_path = f"{f}/optimizer" + if not shard: + model_ckpt_path = f"{model_ckpt_path}.safetensors" # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here - booster.save_model(model, model_ckpt_path, shard=shard) + booster.save_model(model, model_ckpt_path, shard=shard, use_async=use_async) booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) dist.barrier() From 9c534430fadaf7c55e534d6e7c3077be78769ad6 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 15 Nov 2024 02:41:24 +0000 Subject: [PATCH 04/12] fix --- .../test_low_level_zero_checkpoint_io.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index c8d0e4251ebd..8388b315c162 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -46,16 +46,16 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us import os os.makedirs(output_dir, exist_ok=True) - - with open(output_dir, "rb") as f: - model_ckpt_path = f"{f}/model" - optimizer_ckpt_path = f"{f}/optimizer" - if not shard: - model_ckpt_path = f"{model_ckpt_path}.safetensors" + model_ckpt_path = f"{output_dir}/model" + optimizer_ckpt_path = f"{output_dir}/optimizer" + if not shard: + model_ckpt_path = f"{model_ckpt_path}.safetensors" + print("model_ckpt_path: ", model_ckpt_path) # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here booster.save_model(model, model_ckpt_path, shard=shard, use_async=use_async) booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) - + booster.checkpoint_io._sync_d2h() + booster.checkpoint_io._sync_io() dist.barrier() new_model = resnet18() From 290df416a3dc06a3bac740bcdb4b902ec72d46b3 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 15 Nov 2024 04:40:54 +0000 Subject: [PATCH 05/12] fix --- colossalai/checkpoint_io/utils.py | 2 -- .../test_low_level_zero_checkpoint_io.py | 20 +++++++++---------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 9bc205705079..610d05fe4910 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -811,7 +811,6 @@ def load_state_dict(checkpoint_file_path: Path): from safetensors import safe_open state_dict = {} - print("checkpoint_file_path:", checkpoint_file_path) with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) @@ -819,7 +818,6 @@ def load_state_dict(checkpoint_file_path: Path): else: # load with torch - print("checkpoint_file_path:", checkpoint_file_path) return torch.load(checkpoint_file_path, map_location=torch.device("cpu")) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 8388b315c162..6ef6601a00cf 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -42,17 +42,16 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us loss = criterion(output) booster.backward(loss, optimizer) optimizer.step() - output_dir = "./checkpoints" - import os - - os.makedirs(output_dir, exist_ok=True) - model_ckpt_path = f"{output_dir}/model" - optimizer_ckpt_path = f"{output_dir}/optimizer" - if not shard: - model_ckpt_path = f"{model_ckpt_path}.safetensors" - print("model_ckpt_path: ", model_ckpt_path) - # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + if use_async: + model_ckpt_path = f"{model_ckpt_path}.safetensors" + if not use_async: + model_ckpt_path = f"{model_ckpt_path}.pt" booster.save_model(model, model_ckpt_path, shard=shard, use_async=use_async) + + # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) booster.checkpoint_io._sync_d2h() booster.checkpoint_io._sync_io() @@ -79,6 +78,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict()) + torch.cuda.empty_cache() From 9475570ae2ab85842595645d3d3587d1c9b0f7f0 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 15 Nov 2024 04:45:39 +0000 Subject: [PATCH 06/12] fix --- .../test_gemini_checkpoint_io.py | 8 +--- .../test_general_checkpoint_io.py | 47 ------------------- 2 files changed, 1 insertion(+), 54 deletions(-) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index c810244a9387..8bee8fe97290 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -33,12 +33,9 @@ @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_bert_for_sequence_classification"]) @parameterize("use_safetensors", [False, True]) -@parameterize("use_async", [False, True]) @parameterize("tp_size", [1, 2]) @parameterize("zero_size", [2]) -def exam_state_dict_with_origin( - placement_config, model_name, use_safetensors: bool, use_async: bool, tp_size: int, zero_size: int -): +def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -73,12 +70,9 @@ def exam_state_dict_with_origin( "", (model_size / 3), use_safetensors=use_safetensors, - use_async=use_async, ) dist.barrier() - print("pretrained_path: ", pretrained_path) new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) - print("new_bert_model: ", new_bert_model) check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict()) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 3d224e144b3d..8431036df6b7 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -106,53 +106,6 @@ def test_sharded_model_checkpoint(use_safetensors: bool): check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) -@parameterize("use_async", [True, False]) -def test_unsharded_checkpoint(use_async: bool): - # create a model and optimizer - model = resnet18() - optimizer = Adam(model.parameters(), lr=0.001) - lr_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10) - - # create test data sample - x = torch.randn(1, 3, 224, 224) - - # run fwd and bwd - y = model(x) - loss = y.sum() - loss.backward() - optimizer.step() - lr_scheduler.step() - - # create a temp file for checkpoint - if use_async: - suffix = ".safetensors" - else: - suffix = ".bin" - model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix) - optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() - lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile() - - # save the model, optimizer, lr_scheduler - ckpt_io = GeneralCheckpointIO() - ckpt_io.save_model(model, model_ckpt_tempfile.name, use_async=use_async) - ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name) - ckpt_io.save_lr_scheduler(lr_scheduler, lr_scheduler_ckpt_tempfile.name) - - # create new model - new_model = resnet18() - new_optimizer = Adam(new_model.parameters(), lr=0.001) - new_lr_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10) - - # load the model, optimizer, lr_scheduler - ckpt_io.load_model(new_model, model_ckpt_tempfile.name) - ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) - ckpt_io.load_lr_scheduler(new_lr_scheduler, lr_scheduler_ckpt_tempfile.name) - - # check for model and optimizer state dict recursively - check_state_dict_equal(model.state_dict(), new_model.state_dict()) - check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict()) - - def test_sharded_optimizer_checkpoint(): # create a model and optimizer model = resnet18() From 70284f2dd8fdb62b82e94310e064ac12674a195c Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 15 Nov 2024 04:58:37 +0000 Subject: [PATCH 07/12] fix --- colossalai/checkpoint_io/general_checkpoint_io.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index ac4893380a7b..8c4b2264fe25 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -14,6 +14,7 @@ from .index_file import CheckpointIndexFile from .utils import ( async_save_state_dict_shards, + create_pinned_state_dict, get_model_base_filenames, get_optimizer_base_filenames, is_safetensors_available, @@ -56,11 +57,11 @@ def save_unsharded_model( from tensornvme.async_file_io import AsyncFileWriter writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") - # if id(model) not in self.pinned_state_dicts: - # self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) writer.sync_before_step() self.async_writers.append(writer) - move_and_save(writer, state_dict, None) + move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)]) writer.synchronize() else: From 996b2346f7ab03468da95c0ba66dad1bdb38be25 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 15 Nov 2024 05:00:42 +0000 Subject: [PATCH 08/12] fix --- colossalai/checkpoint_io/general_checkpoint_io.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 8c4b2264fe25..580be91ca0d8 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -59,10 +59,8 @@ def save_unsharded_model( writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread") if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) - writer.sync_before_step() self.async_writers.append(writer) move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)]) - writer.synchronize() else: # save the checkpoint From b8721a19909fc93f8a68496bf378cb4f96a9b90e Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 15 Nov 2024 07:08:55 +0000 Subject: [PATCH 09/12] fix --- .../booster/plugin/torch_fsdp_plugin.py | 61 +++++++------------ colossalai/checkpoint_io/utils.py | 4 -- .../test_low_level_zero_checkpoint_io.py | 16 +++-- 3 files changed, 35 insertions(+), 46 deletions(-) diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 720a674beae7..964e96e41567 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -60,20 +60,16 @@ def save_unsharded_model( checkpoint: str, gather_dtensor: bool, use_safetensors: bool, - use_async: bool = False, ): """ Save model to checkpoint but only on master process. """ assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" model = model.unwrap() - if use_async: - super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) - else: - cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): - full_model_state = model.state_dict() - utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + full_model_state = model.state_dict() + utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): """ @@ -92,7 +88,6 @@ def save_sharded_model( prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, - use_async: bool = False, ): """ Save model to checkpoint but only on master process. @@ -113,37 +108,27 @@ def save_sharded_model( weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) - if use_async: - super().save_sharded_model( - model=model, - checkpoint_path=checkpoint_path, - gather_dtensor=gather_dtensor, - prefix=prefix, - use_safetensors=use_safetensors, - use_async=use_async, - ) - else: - # In general cases, is_master is set to True to get the right behavior. - total_size = utils.save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=self.coordinator.is_master(), - use_safetensors=use_safetensors, + # In general cases, is_master is set to True to get the right behavior. + total_size = utils.save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=self.coordinator.is_master(), + use_safetensors=use_safetensors, + ) + + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + utils.save_config_file(model.unwrap(), checkpoint_path) + self.logger.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." ) - # only save the index file on the master rank - if self.coordinator.is_master(): - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - utils.save_config_file(model.unwrap(), checkpoint_path) - self.logger.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) - def load_sharded_model( self, model: nn.Module, diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 610d05fe4910..8487064f5fee 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -585,10 +585,6 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): raise Exception("load the model using `safetensors`, but no file endwith .safetensors") if use_safetensors: from safetensors.torch import load_file as safe_load_file - from safetensors.torch import safe_open - - with safe_open(checkpoint_file, framework="pt") as f: - f.metadata() return safe_load_file(checkpoint_file) else: diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 6ef6601a00cf..fe521453d16c 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -26,7 +26,7 @@ # only test 2 is fine @clear_cache_before_run() @parameterize("stage", [2]) -@parameterize("shard", [True, False]) +@parameterize("shard", [True]) @parameterize("offload", [False, True]) @parameterize("use_async", [False, True]) def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool): @@ -42,14 +42,22 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us loss = criterion(output) booster.backward(loss, optimizer) optimizer.step() + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" - if use_async: + if not shard and not use_async: model_ckpt_path = f"{model_ckpt_path}.safetensors" - if not use_async: + if not shard and use_async: model_ckpt_path = f"{model_ckpt_path}.pt" - booster.save_model(model, model_ckpt_path, shard=shard, use_async=use_async) + + booster.save_model( + model, + model_ckpt_path, + shard=shard, + use_async=use_async, + ) # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) From 680e6017a29ab79493c3cff1c71c98a67528dc80 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 15 Nov 2024 08:17:16 +0000 Subject: [PATCH 10/12] fix --- colossalai/booster/plugin/torch_fsdp_plugin.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 964e96e41567..3de3d879c83f 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -55,11 +55,7 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path optimizer.load_state_dict(sharded_osd) def save_unsharded_model( - self, - model: ModelWrapper, - checkpoint: str, - gather_dtensor: bool, - use_safetensors: bool, + self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool ): """ Save model to checkpoint but only on master process. From 889282b10b8acb7a5cdbd9276c0636143f9f0e90 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 15 Nov 2024 08:39:47 +0000 Subject: [PATCH 11/12] fix --- .../test_low_level_zero_checkpoint_io.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index fe521453d16c..5e3cc2bdc6b3 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -26,7 +26,7 @@ # only test 2 is fine @clear_cache_before_run() @parameterize("stage", [2]) -@parameterize("shard", [True]) +@parameterize("shard", [False, True]) @parameterize("offload", [False, True]) @parameterize("use_async", [False, True]) def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, use_async: bool): @@ -48,10 +48,9 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" if not shard and not use_async: - model_ckpt_path = f"{model_ckpt_path}.safetensors" - if not shard and use_async: model_ckpt_path = f"{model_ckpt_path}.pt" - + if not shard and use_async: + model_ckpt_path = f"{model_ckpt_path}.safetensors" booster.save_model( model, model_ckpt_path, From 97ec19831a50420a30f5aa45d2df622259506065 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Fri, 15 Nov 2024 09:09:31 +0000 Subject: [PATCH 12/12] fix --- colossalai/booster/plugin/torch_fsdp_plugin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index 3de3d879c83f..d309370dd620 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -55,7 +55,7 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path optimizer.load_state_dict(sharded_osd) def save_unsharded_model( - self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool + self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False ): """ Save model to checkpoint but only on master process. @@ -84,6 +84,7 @@ def save_sharded_model( prefix: Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, + use_async: bool = False, ): """ Save model to checkpoint but only on master process.