From 0d914c936273596f40a01160ebfba4791b869f7d Mon Sep 17 00:00:00 2001 From: luchen Date: Thu, 6 Apr 2023 18:04:55 +0800 Subject: [PATCH 1/6] [checkpoint] support huggingface style sharded checkpoint, to be compatible with hf file naming format --- colossalai/checkpoint_io/general_checkpoint_io.py | 1 + colossalai/checkpoint_io/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 2a76f1718469..1b078d315a64 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -20,6 +20,7 @@ ) from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + __all__ = ['GeneralCheckpointIO'] diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 81b666da5c78..a1e5414b637e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -6,9 +6,9 @@ from colossalai.tensor.d_tensor.d_tensor import DTensor SAFE_WEIGHTS_NAME = "model.safetensors" -WEIGHTS_NAME = "model.bin" +WEIGHTS_NAME = "pytorch_model.bin" SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" -WEIGHTS_INDEX_NAME = "model.bin.index.json" +WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" # ====================================== # General helper functions From 80f4ca2b4c45abf1edbdf824622b1b3c6e77c643 Mon Sep 17 00:00:00 2001 From: luchen Date: Thu, 6 Apr 2023 18:06:17 +0800 Subject: [PATCH 2/6] [checkpoint] support huggingface style sharded checkpoint, to be compatible with hf file naming format --- colossalai/checkpoint_io/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index a1e5414b637e..d4adbebd34a7 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -27,7 +27,6 @@ def calculate_tensor_size(tensor: 
torch.Tensor) -> float: """ return tensor.numel() * tensor.element_size() / 1024 / 1024 - def is_safetensors_available() -> bool: """ Check whether safetensors is available. From 485c91a45ad9673c9ac179d5989354e54e7ddd7c Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 12 Apr 2023 13:45:02 +0800 Subject: [PATCH 3/6] [checkpoint] Shard saved checkpoint add 'variant' field to customize filename --- colossalai/checkpoint_io/checkpoint_io_base.py | 4 ++-- .../checkpoint_io/general_checkpoint_io.py | 12 ++++++++---- colossalai/checkpoint_io/utils.py | 16 ++++++++++++++-- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index b91b00831e52..9bcdb7c79ea2 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -104,7 +104,7 @@ def save_model(self, checkpoint: str, shard: bool = False, gather_dtensor: bool = True, - prefix: str = None, + variant: str = None, size_per_shard: int = 1024, use_safetensors: bool = False): """ @@ -138,7 +138,7 @@ def save_model(self, model = model.unwrap() if shard: - self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) + self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors) else: self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 1b078d315a64..8265d8039baf 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -6,6 +6,7 @@ import os import json import gc +from typing import Optional from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile @@ -16,7 +17,8 @@ is_safetensors_available, shard_checkpoint, load_shard_state_dict, - 
load_state_dict_into_model + load_state_dict_into_model, + add_variant ) from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME @@ -69,8 +71,8 @@ def save_unsharded_optimizer( save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) - def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, - prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False): + def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, + variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): """ implement this method as it can be supported by Huggingface model, save shard model, save model to multiple files @@ -84,6 +86,7 @@ def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dten # shard checkpoint state_dict = model.state_dict() weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + weights_name = add_variant(weights_name, variant) shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) # Save the model @@ -93,7 +96,8 @@ def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dten # save index file save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - save_index_file = os.path.join(checkpoint_path, save_index_file) + + save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant)) with open(save_index_file, "w", encoding="utf-8") as f: content = json.dumps(index, indent=2, sort_keys=True) + "\n" f.write(content) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index d4adbebd34a7..37d22d08df40 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -4,6 +4,7 @@ import torch.nn as nn from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple from 
colossalai.tensor.d_tensor.d_tensor import DTensor +import re SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -357,13 +358,14 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: checkpoint_path = Path(checkpoint_path) if checkpoint_path.is_file(): # check if it is .index.json - if checkpoint_path.name.endswith('.index.json'): + reg = re.compile("(.*?).index((\..*)?).json") + if reg.fullmatch(checkpoint_path.name) is not None: return True, checkpoint_path else: return False, None elif checkpoint_path.is_dir(): # check if there is only one a file ending with .index.json in this directory - index_files = list(checkpoint_path.glob('*.index.json')) + index_files = list(checkpoint_path.glob('*.index.*json')) # if we found a .index.json file, make sure there is only one if len(index_files) > 0: @@ -405,3 +407,13 @@ def load_state_dict(checkpoint_file_path: Path): else: # load with torch return torch.load(checkpoint_file_path) + + + +def add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None and len(variant) > 0: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name From 8364b4331112ed91179fa5eac1c103e9d1924d90 Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 12 Apr 2023 13:53:12 +0800 Subject: [PATCH 4/6] [checkpoint] Shard saved checkpoint add 'variant' field to customize filename --- .../checkpoint_io/checkpoint_io_base.py | 8 ++++--- .../checkpoint_io/general_checkpoint_io.py | 22 +++++++++++++++---- colossalai/checkpoint_io/index_file.py | 3 ++- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 9bcdb7c79ea2..d1b5eba5324c 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -1,6 +1,7 @@ from abc import ABC, 
abstractmethod from pathlib import Path from typing import Union +from typing import Optional import torch import torch.nn as nn @@ -9,7 +10,8 @@ from colossalai.interface import ModelWrapper -from .utils import has_index_file +# from .utils import has_index_file +from utils import has_index_file __all__ = ['CheckpointIO'] @@ -129,7 +131,7 @@ def save_model(self, multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure that the checkpoint path is a directory path instead of a file path. gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. - prefix (str): prefix for the model checkpoint file name when shard=True. Default: None. + variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved """ @@ -219,7 +221,7 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): pass @abstractmethod - def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str, + def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str], size_per_shard: int, use_safetensors: bool): """ Save model to sharded checkpoint. 
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 8265d8039baf..1b44c8e223db 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,9 +8,23 @@ import gc from typing import Optional -from .checkpoint_io_base import CheckpointIO -from .index_file import CheckpointIndexFile -from .utils import ( +# from .checkpoint_io_base import CheckpointIO +# from .index_file import CheckpointIndexFile +# from .utils import ( +# has_index_file, +# load_state_dict, +# save_state_dict, +# is_safetensors_available, +# shard_checkpoint, +# load_shard_state_dict, +# load_state_dict_into_model, +# add_variant +# ) +# from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + +from checkpoint_io_base import CheckpointIO +from index_file import CheckpointIndexFile +from utils import ( has_index_file, load_state_dict, save_state_dict, @@ -20,7 +34,7 @@ load_state_dict_into_model, add_variant ) -from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME +from utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME __all__ = ['GeneralCheckpointIO'] diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 89224787a91b..45241b4de7ec 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -2,7 +2,8 @@ from pathlib import Path from typing import Any, List, Union -from .utils import is_dtensor_checkpoint +# from .utils import is_dtensor_checkpoint +from utils import is_dtensor_checkpoint __all__ = ['CheckpointIndexFile'] From 243a25dae9862491919eab3be5553400c813ba8b Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 12 Apr 2023 13:55:52 +0800 Subject: [PATCH 5/6] [checkpoint] Shard saved checkpoint add 'variant' field to customize filename --- 
colossalai/checkpoint_io/general_checkpoint_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 1b44c8e223db..30b020834414 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -85,7 +85,7 @@ def save_unsharded_optimizer( save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) - def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, + def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): """ implement this method as it can be supported by Huggingface model, From 1c5a08f6862e61f002b2ab25f610f7915ec6c0b7 Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 12 Apr 2023 13:57:56 +0800 Subject: [PATCH 6/6] [checkpoint] Shard saved checkpoint add 'variant' field to customize filename --- .../checkpoint_io/checkpoint_io_base.py | 3 +-- .../checkpoint_io/general_checkpoint_io.py | 22 ++++--------------- colossalai/checkpoint_io/index_file.py | 3 +-- 3 files changed, 6 insertions(+), 22 deletions(-) diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index d1b5eba5324c..3f8b0b0a6b47 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -10,8 +10,7 @@ from colossalai.interface import ModelWrapper -# from .utils import has_index_file -from utils import has_index_file +from .utils import has_index_file __all__ = ['CheckpointIO'] diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 30b020834414..bf584f45d045 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ 
b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,23 +8,9 @@ import gc from typing import Optional -# from .checkpoint_io_base import CheckpointIO -# from .index_file import CheckpointIndexFile -# from .utils import ( -# has_index_file, -# load_state_dict, -# save_state_dict, -# is_safetensors_available, -# shard_checkpoint, -# load_shard_state_dict, -# load_state_dict_into_model, -# add_variant -# ) -# from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME - -from checkpoint_io_base import CheckpointIO -from index_file import CheckpointIndexFile -from utils import ( +from .checkpoint_io_base import CheckpointIO +from .index_file import CheckpointIndexFile +from .utils import ( has_index_file, load_state_dict, save_state_dict, @@ -34,7 +20,7 @@ load_state_dict_into_model, add_variant ) -from utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME +from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME __all__ = ['GeneralCheckpointIO'] diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 45241b4de7ec..89224787a91b 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -2,8 +2,7 @@ from pathlib import Path from typing import Any, List, Union -# from .utils import is_dtensor_checkpoint -from utils import is_dtensor_checkpoint +from .utils import is_dtensor_checkpoint __all__ = ['CheckpointIndexFile']