From 92aaa2058329b4e61285da92c08a711dd249627a Mon Sep 17 00:00:00 2001 From: luchen Date: Thu, 6 Apr 2023 12:58:20 +0800 Subject: [PATCH 1/5] [checkpoint] support huggingface style sharded checkpoint --- .../checkpoint_io/general_checkpoint_io.py | 107 +++++++--- colossalai/checkpoint_io/index_file.py | 15 ++ colossalai/checkpoint_io/utils.py | 193 +++++++++++++++++- .../test_general_checkpoint_io.py | 90 ++++++-- 4 files changed, 357 insertions(+), 48 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index c779f4c17355..3d0e8e31b33c 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -2,37 +2,29 @@ import torch.nn as nn from torch.optim import Optimizer +import logging +import os +import json +import gc from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile -from .utils import has_index_file, load_state_dict, save_state_dict +from .utils import ( + has_index_file, + load_state_dict, + save_state_dict, + is_safetensors_available, + shard_checkpoint, + load_shard_state_dict, + load_state_dict_into_model + ) +from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME __all__ = ['GeneralCheckpointIO'] class GeneralCheckpointIO(CheckpointIO): - - def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool): - # load the index file - index_file = CheckpointIndexFile.from_file(index_file_path) - - # iterate over the shard checkpoint files - # and load each - index_file.assert_no_dtensor_checkpoint() - checkpoint_file_list, _ = index_file.get_checkpoint_fileanames() - for shard_file in checkpoint_file_list: - shard_checkpoint = load_state_dict(shard_file) - model.load_state_dict(shard_checkpoint, strict=strict) - - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): - checkpoint = load_state_dict(checkpoint) - model.load_state_dict(checkpoint, strict=strict) - - def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str, - size_per_shard: int, use_safetensors: bool): - # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model - raise NotImplementedError("Sharded model checkpoint is not supported yet.") - + """Checkpoint IO""" def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): state_dict = model.state_dict() @@ -68,3 +60,72 @@ def save_unsharded_optimizer( ): # TODO(FrankLeeeee): handle distributed tensors save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) + + + def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, \ + prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False): + """ + implement this method as it can be supported by Huggingface model, + save shard model, save model to multiple files + """ + if os.path.isfile(checkpoint_path): + logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + return + + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + + # shard checkpoint + state_dict = model.state_dict() + weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) + + # Save the model + for shard_file, shard in shards.items(): + checkpoint_file_path = os.path.join(checkpoint_path, shard_file) + save_state_dict(shard, checkpoint_file_path, use_safetensors) + + # save index file + save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(checkpoint_path, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logging.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): + """ + load shard model, load model from multiple files + """ + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + checkpoint_path = checkpoint_index_file.parent + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") + + # read checkpoint index file + ckpt_index_file = CheckpointIndexFile() + ckpt_index_file.load(checkpoint_index_file) + checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() + sharded_metadata = ckpt_index_file.get_checkpoint_metadata() + # checkpoint_files, sharded_metadata = get_checkpoint_shard_files(checkpoint_path, checkpoint_index_file) + missing_keys = sharded_metadata["all_checkpoint_keys"] + + for shard_file in checkpoint_files: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) + load_state_dict_into_model(model, state_dict, missing_keys, strict) + del state_dict + gc.collect() + + if strict and len(missing_keys) > 0: + error_msgs = 'Missing key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in missing_keys)) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 32ff1b762e88..4cabae5b3b23 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -148,3 +148,18 @@ def get_checkpoint_file(self, param_name: str) -> str: """ ckpt_path = self.weight_map[param_name] return ckpt_path + + + def get_checkpoint_metadata(self) -> str: + """ + Get the checkpoint index metadata. + + Args: + + Returns: + str: checkpoint file name. + """ + ckpt_metadata = self.metadata + ckpt_metadata["all_checkpoint_keys"] = list(self.weight_map.keys()) + ckpt_metadata["weight_map"] = self.weight_map.copy() + return ckpt_metadata diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 76c9db0afaff..24fdcf38f642 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,12 +1,59 @@ +# coding=utf-8 +import importlib.util +import sys +from packaging import version from pathlib import Path -from typing import List, Optional, Tuple - import torch +import torch.nn as nn +import json +from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple +from colossalai.tensor.d_tensor.d_tensor import DTensor +import logging # ====================================== # General helper functions # ====================================== +SAFE_WEIGHTS_NAME = "model.safetensors" +WEIGHTS_NAME = "model.bin" +SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" +WEIGHTS_INDEX_NAME = "model.bin.index.json" + +# ====================================== +# Function for check if safetensor available +# ====================================== + +# if sys.version_info < (3, 8): +# import importlib_metadata +# else: +# import importlib.metadata as importlib_metadata + +# def is_safetensors_available(): +# torch_available, torch_version = is_torch_available() +# if torch_available: +# if version.parse(torch_version) >= version.parse("1.10"): +# return importlib.util.find_spec("safetensors") is not None +# else: +# return False +# else: +# return importlib.util.find_spec("safetensors") is not None + + +# def is_torch_available(): +# torch_version = "N/A" +# torch_available = importlib.util.find_spec("torch") is not None +# if torch_available: +# try: +# torch_version = importlib_metadata.version("torch") +# except importlib_metadata.PackageNotFoundError: +# torch_available = False + +# return torch_available, torch_version + +# if is_safetensors_available(): +# from safetensors import safe_open +# from safetensors.torch import load_file as safe_load_file +# from safetensors.torch import save_file as safe_save_file def calculate_tensor_size(tensor: torch.Tensor) -> float: """ @@ -68,6 +115,144 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False +# ====================================== +# Helper functions for saving shard file +# ====================================== +def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME): + + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + """ + sharded_state_dicts = [] + current_block = {} + current_block_size = 0 + total_size = 0 + + for key, weight in state_dict.items(): + if type(weight) != DTensor: + weight_size = calculate_tensor_size(weight) + + # If this weight is going to tip up over the maximal size, we split. + if current_block_size + weight_size > max_shard_size: + sharded_state_dicts.append(current_block) + current_block = {} + current_block_size = 0 + + current_block[key] = weight + current_block_size += weight_size + total_size += weight_size + + # Add the last block + sharded_state_dicts.append(current_block) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") + shard_file = shard_file.replace( + ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" + ) + shards[shard_file] = shard + for key in shard.keys(): + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + +# def get_checkpoint_shard_files(checkpoint_root: Path, checkpoint_index_file: Path): +# """get checkpoint shard files""" +# with checkpoint_index_file.open("r") as f: +# index = json.loads(f.read()) +# logging.info(index) +# shard_filenames = sorted(set(index["weight_map"].values())) +# sharded_metadata = index["metadata"] +# sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys()) +# sharded_metadata["weight_map"] = index["weight_map"].copy() + +# shard_filenames = [Path.joinpath(checkpoint_root, f) for f in shard_filenames] + +# return shard_filenames, sharded_metadata + +def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): + """ + load shard state dict into model + """ + if use_safetensors and not checkpoint_file.suffix == ".safetensors": + raise Exception("load the model using `safetensors`, but no file endwith .safetensors") + if use_safetensors: + from safetensors.torch import safe_open + from safetensors.torch import load_file as safe_load_file + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata["format"] != "pt": + raise NotImplementedError( + f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." + ) + return safe_load_file(checkpoint_file) + else: + return torch.load(checkpoint_file) + +def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + """ + if not isinstance(state_dict, Mapping): + raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) + + unexpected_keys: List[str] = [] + sub_missing_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + state_dict._metadata = metadata + + def load(module: nn.Module, state_dict, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + # Parameters of module and children will start with prefix. We can exit early if there are none in this + # state_dict + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".") + + load(model, state_dict, "") + del load + + # deal with missing key + if len(missing_keys) > 0: + delete_key = [] + for key in missing_keys: + if key not in sub_missing_keys: + delete_key.append(key) + for key in delete_key: + missing_keys.remove(key) + + if strict: + if len(unexpected_keys) > 0: + error_msgs = 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys)) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + # ====================================== # Helper functions for saving state dict # ====================================== @@ -86,8 +271,8 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors assert is_safetensors_available(), "safetensors is not available." assert checkpoint_file_path.endswith('.safetensors'), \ "safetensors only supports .safetensors suffix for checkpoint file." - from safetensors.torch import save_file - save_file(state_dict, checkpoint_file_path) + from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) else: torch.save(state_dict, checkpoint_file_path) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index dfbb16af4ec6..dcc26f17e295 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -1,9 +1,12 @@ import tempfile - import pytest import torch +import logging from torch.optim import Adam from torchvision.models import resnet18 +from pathlib import Path +import os +import subprocess from colossalai.checkpoint_io import GeneralCheckpointIO @@ -11,7 +14,7 @@ # Note: # 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now # 2. we will test on both sharded and unsharded checkpoints -# 3. TODO(FrankLeeeee): implement sharded checkpoint and test it +# 3. implement sharded checkpoint and test it # ======== @@ -51,27 +54,72 @@ def test_unsharded_checkpoint(use_safetensors: bool): ckpt_io.load_model(new_model, model_ckpt_tempfile.name) ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) - # do recursive check for the optimizer state dict - # if the value is a dict, compare its values - # if the value is a list, comapre all elements one-by-one - # if the value is a torch.Tensor, use torch.equal - # otherwise use assertEqual - def recursive_check(d1, d2): - for k, v in d1.items(): - if isinstance(v, dict): - recursive_check(v, d2[k]) - elif isinstance(v, list): - for i in range(len(v)): - if isinstance(v[i], torch.Tensor): - assert torch.equal(v[i], d2[k][i]) - else: - assert v[i] == d2[k][i] - elif isinstance(v, torch.Tensor): - assert torch.equal(v, d2[k]) - else: - assert v == d2[k] # check for model and optimizer state dict recursively recursive_check(model.state_dict(), new_model.state_dict()) recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) + +@pytest.mark.parametrize('use_safetensors', [True, False]) +def test_sharded_checkpoint(use_safetensors: bool): + # create a model and optimizer + model = resnet18() + optimizer = Adam(model.parameters(), lr=0.001) + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # create a temp file for checkpoint + if use_safetensors: + suffix = ".safetensors" + SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" + else: + suffix = ".bin" + WEIGHTS_INDEX_NAME = "model.bin.index.json" + + # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix) + model_ckpt_dir = tempfile.TemporaryDirectory(dir='/tmp') + optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() + + # save the model and optimizer + ckpt_io = GeneralCheckpointIO() + print(model_ckpt_dir.name) + print(os.listdir(model_ckpt_dir.name)) + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False) + + # create new model + new_model = resnet18() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + + ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) + + # check for model and optimizer state dict recursively + recursive_check(model.state_dict(), new_model.state_dict()) recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) + + +# do recursive check for the optimizer state dict +# if the value is a dict, compare its values +# if the value is a list, comapre all elements one-by-one +# if the value is a torch.Tensor, use torch.equal +# otherwise use assertEqual +def recursive_check(d1, d2): + for k, v in d1.items(): + if isinstance(v, dict): + recursive_check(v, d2[k]) + elif isinstance(v, list): + for i in range(len(v)): + if isinstance(v[i], torch.Tensor): + assert torch.equal(v[i], d2[k][i]) + else: + assert v[i] == d2[k][i] + elif isinstance(v, torch.Tensor): + assert torch.equal(v, d2[k]) + else: + assert v == d2[k] From ab7a8790ab4a4340e2769ea26400316e60139f71 Mon Sep 17 00:00:00 2001 From: luchen Date: Thu, 6 Apr 2023 13:38:53 +0800 Subject: [PATCH 2/5] [checkpoint] support huggingface style sharded checkpoint --- colossalai/checkpoint_io/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 24fdcf38f642..17c14c36f17c 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -239,11 +239,11 @@ def load(module: nn.Module, state_dict, prefix=""): # deal with missing key if len(missing_keys) > 0: - delete_key = [] + deleted_keys = [] for key in missing_keys: if key not in sub_missing_keys: - delete_key.append(key) - for key in delete_key: + deleted_keys.append(key) + for key in deleted_keys: missing_keys.remove(key) if strict: From 2e8b8817a26166dbd26d8ef17282bc7f39f5207d Mon Sep 17 00:00:00 2001 From: luchen Date: Thu, 6 Apr 2023 14:29:34 +0800 Subject: [PATCH 3/5] [checkpoint] support huggingface style sharded checkpoint --- .../checkpoint_io/general_checkpoint_io.py | 11 ++-- colossalai/checkpoint_io/index_file.py | 3 +- colossalai/checkpoint_io/utils.py | 52 +------------------ 3 files changed, 11 insertions(+), 55 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 3d0e8e31b33c..5ed6817808dd 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -24,7 +24,13 @@ class GeneralCheckpointIO(CheckpointIO): - """Checkpoint IO""" + """ + Checkpoint IO + """ + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): + checkpoint = load_state_dict(checkpoint) + model.load_state_dict(checkpoint, strict=strict) + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): state_dict = model.state_dict() @@ -107,14 +113,13 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri checkpoint_path = checkpoint_index_file.parent if use_safetensors and not is_safetensors_available(): - raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") # read checkpoint index file ckpt_index_file = CheckpointIndexFile() ckpt_index_file.load(checkpoint_index_file) checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() sharded_metadata = ckpt_index_file.get_checkpoint_metadata() - # checkpoint_files, sharded_metadata = get_checkpoint_shard_files(checkpoint_path, checkpoint_index_file) missing_keys = sharded_metadata["all_checkpoint_keys"] for shard_file in checkpoint_files: diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 4cabae5b3b23..13e50c5267fd 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -159,7 +159,8 @@ def get_checkpoint_metadata(self) -> str: Returns: str: checkpoint file name. """ - ckpt_metadata = self.metadata + ckpt_metadata = {} + ckpt_metadata["metadata"] = self.metadata ckpt_metadata["all_checkpoint_keys"] = list(self.weight_map.keys()) ckpt_metadata["weight_map"] = self.weight_map.copy() return ckpt_metadata diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 17c14c36f17c..7f55b29fc240 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -10,51 +10,15 @@ from colossalai.tensor.d_tensor.d_tensor import DTensor import logging -# ====================================== -# General helper functions -# ====================================== - SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "model.bin" SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" WEIGHTS_INDEX_NAME = "model.bin.index.json" # ====================================== -# Function for check if safetensor available +# General helper functions # ====================================== -# if sys.version_info < (3, 8): -# import importlib_metadata -# else: -# import importlib.metadata as importlib_metadata - -# def is_safetensors_available(): -# torch_available, torch_version = is_torch_available() -# if torch_available: -# if version.parse(torch_version) >= version.parse("1.10"): -# return importlib.util.find_spec("safetensors") is not None -# else: -# return False -# else: -# return importlib.util.find_spec("safetensors") is not None - - -# def is_torch_available(): -# torch_version = "N/A" -# torch_available = importlib.util.find_spec("torch") is not None -# if torch_available: -# try: -# torch_version = importlib_metadata.version("torch") -# except importlib_metadata.PackageNotFoundError: -# torch_available = False - -# return torch_available, torch_version - -# if is_safetensors_available(): -# from safetensors import safe_open -# from safetensors.torch import load_file as safe_load_file -# from safetensors.torch import save_file as safe_save_file - def calculate_tensor_size(tensor: torch.Tensor) -> float: """ Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. @@ -168,20 +132,6 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weigh index = {"metadata": metadata, "weight_map": weight_map} return shards, index -# def get_checkpoint_shard_files(checkpoint_root: Path, checkpoint_index_file: Path): -# """get checkpoint shard files""" -# with checkpoint_index_file.open("r") as f: -# index = json.loads(f.read()) -# logging.info(index) -# shard_filenames = sorted(set(index["weight_map"].values())) -# sharded_metadata = index["metadata"] -# sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys()) -# sharded_metadata["weight_map"] = index["weight_map"].copy() - -# shard_filenames = [Path.joinpath(checkpoint_root, f) for f in shard_filenames] - -# return shard_filenames, sharded_metadata - def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): """ load shard state dict into model From c0045a7901a3b16c8b8a9a2d07a407cc840c8138 Mon Sep 17 00:00:00 2001 From: luchen Date: Thu, 6 Apr 2023 14:53:35 +0800 Subject: [PATCH 4/5] [checkpoint] support huggingface style sharded checkpoint --- .../checkpoint_io/general_checkpoint_io.py | 4 +--- colossalai/checkpoint_io/index_file.py | 22 ++++++++----------- colossalai/checkpoint_io/utils.py | 5 ----- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 5ed6817808dd..5b6847d27faa 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -110,7 +110,6 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri use_safetensors = False if "safetensors" in checkpoint_index_file.name: use_safetensors = True - checkpoint_path = checkpoint_index_file.parent if use_safetensors and not is_safetensors_available(): raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") @@ -119,8 +118,7 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri ckpt_index_file = CheckpointIndexFile() ckpt_index_file.load(checkpoint_index_file) checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() - sharded_metadata = ckpt_index_file.get_checkpoint_metadata() - missing_keys = sharded_metadata["all_checkpoint_keys"] + missing_keys = ckpt_index_file.get_all_param_names() for shard_file in checkpoint_files: state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 13e50c5267fd..caae3c028b3c 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -148,19 +148,15 @@ def get_checkpoint_file(self, param_name: str) -> str: """ ckpt_path = self.weight_map[param_name] return ckpt_path - - - def get_checkpoint_metadata(self) -> str: + + def get_checkpoint_metadata(self) -> str: """ Get the checkpoint index metadata. - - Args: - - Returns: - str: checkpoint file name. """ - ckpt_metadata = {} - ckpt_metadata["metadata"] = self.metadata - ckpt_metadata["all_checkpoint_keys"] = list(self.weight_map.keys()) - ckpt_metadata["weight_map"] = self.weight_map.copy() - return ckpt_metadata + return self.metadata + + def get_all_param_names(self): + """ + Get all the weight keys. + """ + return list(self.weight_map.keys()) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 7f55b29fc240..81b666da5c78 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,14 +1,9 @@ # coding=utf-8 -import importlib.util -import sys -from packaging import version from pathlib import Path import torch import torch.nn as nn -import json from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple from colossalai.tensor.d_tensor.d_tensor import DTensor -import logging SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "model.bin" From 52d5a68205dda6f43bc1bc3d582ccea486465d7c Mon Sep 17 00:00:00 2001 From: luchen Date: Thu, 6 Apr 2023 15:41:09 +0800 Subject: [PATCH 5/5] [checkpoint] support huggingface style sharded checkpoint --- colossalai/checkpoint_io/general_checkpoint_io.py | 5 ++--- colossalai/checkpoint_io/index_file.py | 6 ------ tests/test_checkpoint_io/test_general_checkpoint_io.py | 5 ++--- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 5b6847d27faa..2a76f1718469 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -68,7 +68,7 @@ def save_unsharded_optimizer( save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) - def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, \ + def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False): """ implement this method as it can be supported by Huggingface model, @@ -115,8 +115,7 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") # read checkpoint index file - ckpt_index_file = CheckpointIndexFile() - ckpt_index_file.load(checkpoint_index_file) + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() missing_keys = ckpt_index_file.get_all_param_names() diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index caae3c028b3c..89224787a91b 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -149,12 +149,6 @@ def get_checkpoint_file(self, param_name: str) -> str: ckpt_path = self.weight_map[param_name] return ckpt_path - def get_checkpoint_metadata(self) -> str: - """ - Get the checkpoint index metadata. - """ - return self.metadata - def get_all_param_names(self): """ Get all the weight keys. diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index dcc26f17e295..0b16465b53e3 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -82,13 +82,12 @@ def test_sharded_checkpoint(use_safetensors: bool): WEIGHTS_INDEX_NAME = "model.bin.index.json" # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix) - model_ckpt_dir = tempfile.TemporaryDirectory(dir='/tmp') + model_ckpt_dir = tempfile.TemporaryDirectory() optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() # save the model and optimizer ckpt_io = GeneralCheckpointIO() - print(model_ckpt_dir.name) - print(os.listdir(model_ckpt_dir.name)) + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors) ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False)