[api] implemented the checkpoint io module #3205
Merged · 3 commits
colossalai/checkpoint_io/__init__.py (new file, +4 lines):

```python
from .checkpoint_io_base import CheckpointIO, ShardCheckpointIndexFile
from .general_checkpoint_io import GeneralCheckpointIO

__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile', 'GeneralCheckpointIO']
```
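For orientation, a minimal usage sketch of the three exported names, based on the docstrings in checkpoint_io_base.py below. GeneralCheckpointIO lives in general_checkpoint_io.py, which is not shown in this diff, so the no-argument constructor is an assumption:

```python
import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(16, 4)
optimizer = Adam(model.parameters())

checkpoint_io = GeneralCheckpointIO()    # constructor signature assumed

# single-file checkpoints
checkpoint_io.save_model(model, 'model.pt')
model = checkpoint_io.load_model(model, 'model.pt')

# sharded checkpoint: writes shard files plus a *.index.json into the directory
checkpoint_io.save_model(model, './checkpoints/', shard=True)
model = checkpoint_io.load_model(model, './checkpoints/')

# optimizer state
checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
```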
colossalai/checkpoint_io/checkpoint_io_base.py (new file, +374 lines):
```python
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']


class CheckpointIO(ABC):
| """ | ||
| CheckpointIO is the base class for all checkpoint IO classes. It defines the interface for checkpoint IO. | ||
|
|
||
|
|
||
| Examples: | ||
| >>> from colossalai.checkpoint_io import GeneralCheckpointIO | ||
| >>> checkpoint_io = CheckpointIO() | ||
| >>> | ||
| >>> # load model from checkpoint | ||
| >>> model = checkpoint_io.load_model(model, 'model.pt') | ||
| >>> | ||
| >>> # save model to checkpoint | ||
| >>> checkpoint_io.save_model(model, 'model.pt') | ||
| >>> | ||
| >>> # save model to sharded checkpoints | ||
| >>> checkpoint_io.save_model(model, './checkpoints/', shard=True) | ||
| >>> | ||
| >>> # load model from sharded checkpoints | ||
| >>> model = checkpoint_io.load_model(model, './checkpoints/') | ||
| >>> | ||
| >>> # load optimizer from checkpoint | ||
| >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt') | ||
| >>> | ||
| >>> # save optimizer to checkpoint | ||
| >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt') | ||
|
|
||
| """ | ||
    # ======================================
    # Abstract methods for implementation
    # ======================================

    @abstractmethod
    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
        """
        Load model from checkpoint.

        Args:
            model (nn.Module): model to be loaded.
            checkpoint (str): checkpoint path. This value is made compatible with the model checkpoints in
                mainstream model zoos such as Hugging Face and TIMM. The checkpoint path can be:
                1. a file path, e.g. 'model.pt'
                2. a path to a json file which defines the index to the sharded checkpoint
                3. a path to a folder containing a unique .index.json file for sharded checkpoint
            strict (bool): whether to strictly enforce that the parameter names in the checkpoint match the
                keys returned by this module's state_dict() function.
        """
        pass
    @abstractmethod
    def save_model(self,
                   model: nn.Module,
                   checkpoint: str,
                   prefix: str = None,
                   shard: bool = False,
                   size_per_shard: int = 1024):
        """
        Save model to checkpoint.

        Examples:
            >>> from colossalai.checkpoint_io import GeneralCheckpointIO
            >>> checkpoint_io = GeneralCheckpointIO()
            >>>
            >>> # save model to a single file
            >>> checkpoint_io.save_model(model, 'model.pt')
            >>>
            >>> # save model to a sharded checkpoint
            >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)

        Args:
            model (nn.Module): model to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'model.pt'
                2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
            prefix (str): prefix for the shard file names. Default: None. This value is only used when shard is set to True.
            shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded
                into multiple files. The model shards will be specified by a `model.index.json` file. When shard = True,
                please ensure that the checkpoint path is a directory path instead of a file path.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
        """
        pass
    @abstractmethod
    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        """
        Load optimizer from checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be loaded.
            checkpoint (str): checkpoint path. This value is made compatible with the optimizer checkpoints in
                mainstream model zoos such as Hugging Face and TIMM.
        """
        pass
    @abstractmethod
    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
        """
        Save optimizer to checkpoint.

        Args:
            optimizer (Optimizer): optimizer to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can be:
                1. a file path, e.g. 'optimizer.pt'
                2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
                3. a path to a folder containing a unique .index.json file for sharded checkpoint
            shard (bool): whether to shard the checkpoint. Default: False.
            size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
        """
        pass
    # ============================================
    # methods for loading and saving lr scheduler
    # as this is quite standard, there is no need
    # to make them abstract
    # ============================================

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint.

        Args:
            lr_scheduler (LRScheduler): lr scheduler to be saved.
            checkpoint (str): checkpoint path. The checkpoint path can only be a file path.
        """
        torch.save(lr_scheduler.state_dict(), checkpoint)

    def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Load lr scheduler from checkpoint.

        Args:
            lr_scheduler (LRScheduler): lr scheduler to be loaded.
            checkpoint (str): the path for a single checkpoint file.
        """
        state_dict = torch.load(checkpoint)
        lr_scheduler.load_state_dict(state_dict)
    # ========================================
    # Helper functions for loading state dict
    # ========================================

    def get_sharded_checkpoint_index_file(self, checkpoint_path: Path):
        """
        Get the index file path for a sharded checkpoint.

        Args:
            checkpoint_path (Path): path to the checkpoint.

        Returns:
            Path: path to the index file.
        """
        if checkpoint_path.is_file():
            # a file path must point to the .index.json file directly
            if checkpoint_path.name.endswith('.index.json'):
                return checkpoint_path
            else:
                raise ValueError(f'Invalid checkpoint path: {checkpoint_path}.')
        elif checkpoint_path.is_dir():
            # the directory must contain exactly one file ending with .index.json
            index_files = list(checkpoint_path.glob('*.index.json'))
            if len(index_files) == 1:
                return index_files[0]
            else:
                raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}.')
    def is_sharded_checkpoint(self, checkpoint_path: Path):
        """
        Check whether the checkpoint is sharded.

        Args:
            checkpoint_path (Path): checkpoint path.

        Returns:
            bool: whether the checkpoint is sharded.
        """
        if checkpoint_path.is_file():
            # a single file is a sharded checkpoint only if it is the .index.json file
            if checkpoint_path.name.endswith('.index.json'):
                return True
            else:
                return False
        elif checkpoint_path.is_dir():
            # the directory must contain exactly one file ending with .index.json
            index_files = list(checkpoint_path.glob('*.index.json'))
            if len(index_files) == 1:
                return True
            else:
                raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}.')
    def get_checkpoint_shard_filenames(self, index_file_path: Path):
        """
        Get checkpoint shard file paths from an index file.

        Args:
            index_file_path (Path): path to the index json file.

        Returns:
            list: absolute paths to the checkpoint shard files.
        """
        with open(str(index_file_path), 'r') as f:
            index = json.load(f)

        if "weight_map" in index:
            index = index["weight_map"]

        checkpoint_root_path = index_file_path.absolute().parent

        # read the checkpoint file list from the json file and get a list of unique file names
        checkpoint_files = sorted(list(set(index.values())))

        # get the absolute paths for all checkpoint files
        checkpoint_files = [checkpoint_root_path.joinpath(f) for f in checkpoint_files]
        return checkpoint_files
    def load_safetensors_state_dict(self, *args, **kwargs):
        """
        Load safetensors state dict from checkpoint.
        """
        # TODO(FrankLeeeee): support huggingface safetensors
        raise NotImplementedError("This method is not implemented to support safe tensors")

    def load_state_dict(self, checkpoint_file_path: Path):
        """
        Load state dict from checkpoint.

        Args:
            checkpoint_file_path (Path): path to the checkpoint file.

        Returns:
            dict: state dict.
        """
        return torch.load(str(checkpoint_file_path))

    # ======================================
    # Helper functions for saving state dict
    # ======================================

    def save_safetensors_state_dict(self, *args, **kwargs):
        """
        Save safetensors state dict to checkpoint.
        """
        # TODO(FrankLeeeee): support huggingface safetensors
        raise NotImplementedError("This method is not implemented to support safe tensors")
    def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: str = None):
        """
        Generate checkpoint shard file name.

        Args:
            index (int): index of the shard.
            total_number (int): total number of shards.
            prefix (str): prefix of the shard file name. Default: None.

        Returns:
            str: shard file name, e.g. 'model-0-of-2.bin' when prefix is 'model'.
        """
        if prefix is None:
            return f"{index}-of-{total_number}.bin"
        else:
            return f"{prefix}-{index}-of-{total_number}.bin"
    def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path):
        """
        Save state dict to checkpoint.

        Args:
            state_dict (dict): state dict.
            checkpoint_file_path (Path): path to the checkpoint file.
        """
        torch.save(state_dict, str(checkpoint_file_path))
    def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str,
                                 checkpoint_path: Path):
        """
        Save state dict as a shard.

        Args:
            state_dict (dict): state dict.
            index (int): index of the shard.
            total_number (int): total number of shards.
            prefix (str): prefix of the shard file name.
            checkpoint_path (Path): path to the checkpoint directory.
        """
        # generate the shard name
        shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix)
        shard_file_path = checkpoint_path.joinpath(shard_file_name)

        # save the shard
        self.save_checkpoint(state_dict, shard_file_path)
    def calculate_param_size(self, param: torch.Tensor):
        """
        Calculate the size of a parameter in MB. Used to compute whether a group of params exceeds the shard size.
        If so, a new shard should be created.

        Args:
            param (torch.Tensor): parameter tensor.
        """
        # TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so
        return param.numel() * param.element_size() / 1024 / 1024
```
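Before the index-file data structure, a small sketch of how the concrete helpers above compose during a sharded save. CheckpointIO is abstract, so the no-op DummyCheckpointIO subclass below is purely hypothetical scaffolding to make it instantiable; the paths and the 'model' prefix are invented for illustration:

```python
from pathlib import Path

import torch

from colossalai.checkpoint_io import CheckpointIO    # exported in __init__.py above


# hypothetical no-op subclass: the base class is abstract, so a concrete
# child is needed before the inherited helpers can be called
class DummyCheckpointIO(CheckpointIO):

    def load_model(self, model, checkpoint, strict=True):
        pass

    def save_model(self, model, checkpoint, prefix=None, shard=False, size_per_shard=1024):
        pass

    def load_optimizer(self, optimizer, checkpoint):
        pass

    def save_optimizer(self, optimizer, checkpoint, shard=False, size_per_shard=1024):
        pass


io = DummyCheckpointIO()

# shard names follow {prefix}-{index}-of-{total}.bin, or {index}-of-{total}.bin without a prefix
print(io.generate_checkpoint_shard_file_name(0, 2))                  # 0-of-2.bin
print(io.generate_checkpoint_shard_file_name(0, 2, prefix='model'))  # model-0-of-2.bin

# a 1M-element float32 tensor is 4 MB; this is what gets compared against size_per_shard
param = torch.zeros(1024 * 1024)
print(io.calculate_param_size(param))    # 4.0

# write one shard file, ./checkpoints/model-0-of-2.bin
ckpt_dir = Path('./checkpoints')
ckpt_dir.mkdir(exist_ok=True)
io.save_state_dict_as_shard({'weight': param}, 0, 2, 'model', ckpt_dir)
```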
```python
class ShardCheckpointIndexFile:
    """
    This class is a data structure to keep the content of the index.json file for a sharded checkpoint.

    Example:
        >>> index = ShardCheckpointIndexFile()
        >>> index.load('index.json')
        >>> index.append_meta_data('model_type', 'bert')
        >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin')
        >>> index.export('index.json')
    """

    def __init__(self) -> None:
        self.metadata: dict = dict()
        self.weight_map: dict = dict()

    def load(self, json_path: str):
        """
        Load the index file from a json file.

        Args:
            json_path (str): path to the json file.
        """
        # load the json file
        with open(json_path, 'r') as f:
            index = json.load(f)

        # assign attributes if they exist
        if "metadata" in index:
            self.metadata = index["metadata"]
        if "weight_map" in index:
            self.weight_map = index["weight_map"]

    def export(self, json_path: str):
        """
        Export the index file to a json file.

        Args:
            json_path (str): path to the json file.
        """
        # create the index file
        index = dict()
        index["metadata"] = self.metadata
        index["weight_map"] = self.weight_map

        # export the index file
        with open(json_path, 'w') as f:
            json.dump(index, f, indent=4)

    def append_weight_map(self, param_name: str, shard_file: str):
        """
        Append a weight map entry to the index file.

        Args:
            param_name (str): name of the parameter.
            shard_file (str): name of the shard file.
        """
        self.weight_map[param_name] = shard_file

    def append_meta_data(self, name: str, val: Any):
        """
        Append a metadata entry to the index file.

        Args:
            name (str): name of the metadata.
            val (Any): value of the metadata.
        """
        self.metadata[name] = val
```
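To make the index-file layout concrete, a minimal sketch of building and exporting an index; the shard file names follow generate_checkpoint_shard_file_name's scheme, while the 'total_size' metadata key and the parameter names are invented for the example:

```python
from colossalai.checkpoint_io import ShardCheckpointIndexFile

index = ShardCheckpointIndexFile()
index.append_meta_data('total_size', 8388608)    # 'total_size' is a made-up metadata key
index.append_weight_map('embed.weight', 'model-0-of-2.bin')
index.append_weight_map('fc.weight', 'model-1-of-2.bin')
index.export('model.index.json')

# model.index.json now reads:
# {
#     "metadata": {
#         "total_size": 8388608
#     },
#     "weight_map": {
#         "embed.weight": "model-0-of-2.bin",
#         "fc.weight": "model-1-of-2.bin"
#     }
# }
#
# A directory containing exactly one *.index.json plus its shard files is what
# is_sharded_checkpoint() and get_checkpoint_shard_filenames() expect to find.
```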