101 changes: 82 additions & 19 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -2,37 +2,35 @@

import torch.nn as nn
from torch.optim import Optimizer
import logging
import os
import json
import gc

from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import has_index_file, load_state_dict, save_state_dict
from .utils import (
has_index_file,
load_state_dict,
save_state_dict,
is_safetensors_available,
shard_checkpoint,
load_shard_state_dict,
load_state_dict_into_model
)
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

__all__ = ['GeneralCheckpointIO']


class GeneralCheckpointIO(CheckpointIO):

def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
# load the index file
index_file = CheckpointIndexFile.from_file(index_file_path)

# iterate over the shard checkpoint files
# and load each
index_file.assert_no_dtensor_checkpoint()
checkpoint_file_list, _ = index_file.get_checkpoint_fileanames()
for shard_file in checkpoint_file_list:
shard_checkpoint = load_state_dict(shard_file)
model.load_state_dict(shard_checkpoint, strict=strict)

"""
Checkpoint IO
"""
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
checkpoint = load_state_dict(checkpoint)
model.load_state_dict(checkpoint, strict=strict)

def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str,
size_per_shard: int, use_safetensors: bool):
# TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
raise NotImplementedError("Sharded model checkpoint is not supported yet.")

def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
state_dict = model.state_dict()

@@ -68,3 +66,68 @@ def save_unsharded_optimizer
):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)


def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor: bool = False,
prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False):
"""
Save the model to multiple shard files (HuggingFace-style sharded checkpoint): one state-dict
file per shard plus an index file recording which shard holds each parameter.
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return

Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

# shard checkpoint
state_dict = model.state_dict()
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)

# Save the model
for shard_file, shard in shards.items():
checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)

# save index file
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(checkpoint_path, save_index_file)
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logging.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split into {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
f"index located at {save_index_file}."
)


def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
"""
load shard model, load model from multiple files
"""
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True

if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

# read checkpoint index file
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
missing_keys = ckpt_index_file.get_all_param_names()

for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
load_state_dict_into_model(model, state_dict, missing_keys, strict)
del state_dict
gc.collect()

if strict and len(missing_keys) > 0:
error_msgs = ['Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys))]
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))

6 changes: 6 additions & 0 deletions colossalai/checkpoint_io/index_file.py
@@ -148,3 +148,9 @@ def get_checkpoint_file(self, param_name: str) -> str:
"""
ckpt_path = self.weight_map[param_name]
return ckpt_path

def get_all_param_names(self):
"""
Get all the weight keys.
"""
return list(self.weight_map.keys())
140 changes: 135 additions & 5 deletions colossalai/checkpoint_io/utils.py
@@ -1,13 +1,19 @@
# coding=utf-8
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
from colossalai.tensor.d_tensor.d_tensor import DTensor

SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "model.bin"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "model.bin.index.json"

# ======================================
# General helper functions
# ======================================


def calculate_tensor_size(tensor: torch.Tensor) -> float:
"""
Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
Expand Down Expand Up @@ -68,6 +74,130 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
return False


# ======================================
# Helper functions for saving shard file
# ======================================
def shard_checkpoint(state_dict: dict, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
"""
Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint
does not exceed a given size (in MB).
"""
sharded_state_dicts = []
current_block = {}
current_block_size = 0
total_size = 0

for key, weight in state_dict.items():
if type(weight) != DTensor:
weight_size = calculate_tensor_size(weight)

# If this weight is going to tip up over the maximal size, we split.
if current_block_size + weight_size > max_shard_size:
sharded_state_dicts.append(current_block)
current_block = {}
current_block_size = 0

current_block[key] = weight
current_block_size += weight_size
total_size += weight_size

# Add the last block
sharded_state_dicts.append(current_block)

# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {weights_name: sharded_state_dicts[0]}, None

# Otherwise, let's build the index
weight_map = {}
shards = {}

for idx, shard in enumerate(sharded_state_dicts):
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
shard_file = shard_file.replace(
".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
)
shards[shard_file] = shard
for key in shard.keys():
weight_map[key] = shard_file

# Add the metadata
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
return shards, index
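A rough sketch of what shard_checkpoint returns, assuming calculate_tensor_size reports sizes in MB as its docstring above states (tensor names and shapes are arbitrary):

import torch

from colossalai.checkpoint_io.utils import shard_checkpoint

# two fp32 tensors of roughly 4 MB each
state_dict = {"fc1.weight": torch.randn(1024, 1024), "fc2.weight": torch.randn(1024, 1024)}

# with a 6 MB cap the two tensors cannot share a shard, so two files result
shards, index = shard_checkpoint(state_dict, max_shard_size=6, weights_name="model.bin")

# shards -> {"model-00001-of-00002.bin": {"fc1.weight": ...}, "model-00002-of-00002.bin": {"fc2.weight": ...}}
# index["weight_map"] maps each parameter name to the shard file that stores it,
# and index["metadata"]["total_size"] holds the summed size of all non-DTensor weights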

def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
"""
Load a shard state dict from the given checkpoint file.
"""
if use_safetensors and not checkpoint_file.suffix == ".safetensors":
raise Exception("`use_safetensors` is set, but the checkpoint file does not end with .safetensors")
if use_safetensors:
from safetensors.torch import safe_open
from safetensors.torch import load_file as safe_load_file
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file)

def load_state_dict_into_model(model: nn.Module, state_dict: Mapping, missing_keys: List, strict: bool = False):
r"""Copies parameters and buffers from :attr:`state_dict` into
the given model and its descendants.

Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
"""
if not isinstance(state_dict, Mapping):
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))

unexpected_keys: List[str] = []
sub_missing_keys: List[str] = []
error_msgs: List[str] = []

# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
state_dict._metadata = metadata

def load(module: nn.Module, state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, unexpected_keys, error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args)

for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")

load(model, state_dict, "")
del load

# deal with missing key
if len(missing_keys) > 0:
deleted_keys = []
for key in missing_keys:
if key not in sub_missing_keys:
deleted_keys.append(key)
for key in deleted_keys:
missing_keys.remove(key)

if strict:
if len(unexpected_keys) > 0:
error_msgs.append('Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
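A toy sketch of the intended contract for missing_keys (assuming the corrected argument passing above): the caller seeds it with every parameter name, and the same list shrinks as successive shards are loaded.

import torch.nn as nn

from colossalai.checkpoint_io.utils import load_state_dict_into_model

model = nn.Sequential(nn.Linear(4, 4))
full = model.state_dict()

# pretend the checkpoint was split into two shards
shard_a = {"0.weight": full["0.weight"]}
shard_b = {"0.bias": full["0.bias"]}

missing = list(full.keys())    # start from every parameter name in the model
for shard in (shard_a, shard_b):
    load_state_dict_into_model(model, shard, missing, strict=False)

# after both shards have been applied, `missing` should be empty;
# anything left over was not provided by any shard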

# ======================================
# Helper functions for saving state dict
# ======================================
@@ -86,8 +216,8 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith('.safetensors'), \
"safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file
save_file(state_dict, checkpoint_file_path)
from safetensors.torch import save_file as safe_save_file
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
