4 changes: 2 additions & 2 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -33,15 +33,15 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool =
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         return super().load_unsharded_model(model, checkpoint, strict=strict)

-    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool):
         """
         Save model to checkpoint but only on master process.
         """
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         if self.coordinator.is_master():
             super().save_unsharded_model(model, checkpoint)

-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
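The two overrides above now accept `gather_dtensor` to match the updated base-class signature while keeping the master-only write behavior. As a minimal, self-contained sketch of that pattern using plain `torch.distributed` (the diff's `self.coordinator.is_master()` presumably performs an equivalent rank check):

    import torch
    import torch.distributed as dist


    def save_on_master(state_dict: dict, checkpoint: str) -> None:
        # Only rank 0 writes the file, so concurrent ranks do not clobber each other.
        if not dist.is_initialized() or dist.get_rank() == 0:
            torch.save(state_dict, checkpoint)
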
5 changes: 3 additions & 2 deletions colossalai/checkpoint_io/__init__.py
@@ -1,4 +1,5 @@
-from .checkpoint_io_base import CheckpointIO, ShardCheckpointIndexFile
+from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
+from .index_file import CheckpointIndexFile

-__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile', 'GeneralCheckpointIO']
+__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
332 changes: 73 additions & 259 deletions colossalai/checkpoint_io/checkpoint_io_base.py

Large diffs are not rendered by default.
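Since the `checkpoint_io_base.py` diff is not rendered, the revised `CheckpointIO` interface can only be inferred from the overrides shown elsewhere in this PR. A hedged sketch of what those signatures imply (note that `TorchDDPCheckpointIO.save_unsharded_model` above does not take `use_safetensors`, so the exact base signature is uncertain):

    from abc import ABC, abstractmethod
    from pathlib import Path

    import torch.nn as nn
    from torch.optim import Optimizer


    class CheckpointIO(ABC):
        # Inferred sketch only: reconstructed from the GeneralCheckpointIO and
        # TorchDDPCheckpointIO overrides rendered in this PR, not from the base file itself.

        @abstractmethod
        def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool,
                                 use_safetensors: bool):
            ...

        @abstractmethod
        def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
            ...

        @abstractmethod
        def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
            ...
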

53 changes: 39 additions & 14 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -4,42 +4,67 @@
 from torch.optim import Optimizer

 from .checkpoint_io_base import CheckpointIO
+from .index_file import CheckpointIndexFile
+from .utils import has_index_file, load_state_dict, save_state_dict

 __all__ = ['GeneralCheckpointIO']


 class GeneralCheckpointIO(CheckpointIO):

-    def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
-        index_file_path = self.get_sharded_checkpoint_index_file(checkpoint)
+    def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool):
+        # load the index file
+        index_file = CheckpointIndexFile.from_file(index_file_path)

         # iterate over the shard checkpoint files
         # and load each
-        shard_files = self.get_checkpoint_shard_filenames(index_file_path)
-        for shard_file in shard_files:
-            shard_checkpoint = self.load_state_dict(shard_file)
+        index_file.assert_no_dtensor_checkpoint()
+        checkpoint_file_list, _ = index_file.get_checkpoint_filenames()
+        for shard_file in checkpoint_file_list:
+            shard_checkpoint = load_state_dict(shard_file)
             model.load_state_dict(shard_checkpoint, strict=strict)

-    def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool):
-        checkpoint = self.load_state_dict(str(checkpoint))
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+        checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)

-    def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int):
+    def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str,
+                           size_per_shard: int, use_safetensors: bool):
         # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model
         raise NotImplementedError("Sharded model checkpoint is not supported yet.")

-    def save_unsharded_model(self, model: nn.Module, checkpoint: Path):
-        self.save_checkpoint(model.state_dict(), checkpoint)
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        state_dict = model.state_dict()
+
+        # TODO(FrankLeeeee): add support for gather_dtensor
+        if gather_dtensor:
+            pass
+
+        # save the checkpoint
+        save_state_dict(state_dict, checkpoint, use_safetensors)

     def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
         raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

     def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        checkpoint = self.load_state_dict(checkpoint)
+        checkpoint = load_state_dict(checkpoint)
         optimizer.load_state_dict(checkpoint)

-    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
+    def save_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+    ):
         raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
-        self.save_checkpoint(optimizer.state_dict(), checkpoint)
+    def save_unsharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+    ):
+        # TODO(FrankLeeeee): handle distributed tensors
+        save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
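A quick usage sketch of the revised `GeneralCheckpointIO` signatures (calling the internal methods directly for illustration; the public `CheckpointIO` entry points may wrap them differently):

    import torch.nn as nn

    from colossalai.checkpoint_io import GeneralCheckpointIO

    model = nn.Linear(4, 4)
    ckpt_io = GeneralCheckpointIO()

    # gather_dtensor is accepted but currently a no-op (see the TODO above);
    # use_safetensors=False keeps the plain torch.save format.
    ckpt_io.save_unsharded_model(model, "model.bin", gather_dtensor=False, use_safetensors=False)
    ckpt_io.load_unsharded_model(model, "model.bin", strict=True)
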
150 changes: 150 additions & 0 deletions colossalai/checkpoint_io/index_file.py
@@ -0,0 +1,150 @@
import json
from pathlib import Path
from typing import Any, List, Tuple, Union

from .utils import is_dtensor_checkpoint

__all__ = ['CheckpointIndexFile']


class CheckpointIndexFile:
    """
    This class is a data structure to keep the content in the index.json file for sharded checkpoint.

    Example:
        >>> index = CheckpointIndexFile.from_file('model.index.json')
        >>> index.append_meta_data('model_type', 'bert')
        >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'model_0001-of-0002.bin')
        >>> index.export('new_index.json')
    """

    def __init__(self) -> None:
        self.root_path = None
        self.metadata: dict = dict()
        self.weight_map: dict = dict()

    @staticmethod
    def from_file(index_path: Union[str, Path]):
        """
        Create a CheckpointIndexFile object from a json file.

        Args:
            index_path (str): path to the json file.

        Returns:
            CheckpointIndexFile: CheckpointIndexFile object.
        """
        index = CheckpointIndexFile()
        index.load(index_path)
        return index

    def load(self, json_path: str):
        """
        Load the index file from a json file.

        Args:
            json_path (str): path to the json file.
        """
        # load the json file
        with open(json_path, 'r') as f:
            index = json.load(f)

        # assign attributes if they exist
        if "metadata" in index:
            self.metadata = index["metadata"]
        if "weight_map" in index:
            self.weight_map = index["weight_map"]

        # assign the root directory for the index file
        self.root_path = Path(json_path).absolute().parent

    def export(self, json_path: str):
        """
        Export the index file to a json file.

        Args:
            json_path (str): path to the json file.
        """
        # create the index file
        index = dict()
        index["metadata"] = self.metadata
        index["weight_map"] = self.weight_map

        # export the index file
        with open(json_path, 'w') as f:
            json.dump(index, f, indent=4)

    def append_weight_map(self, param_name: str, shard_file: str):
        """
        Append a weight map entry to the index file.

        Args:
            param_name (str): name of the parameter.
            shard_file (str): name of the shard file.
        """
        self.weight_map[param_name] = shard_file

    def append_meta_data(self, name: str, val: Any):
        """
        Append a metadata entry to the index file.

        Args:
            name (str): name of the metadata.
            val (Any): value of the metadata.
        """
        self.metadata[name] = val

    def contains_dtensor(self) -> bool:
        """
        Check if the index file contains any distributed tensor. The distributed tensors will be stored in
        `dtensor/module.linear.weight.*.bin` or `dtensor/module.linear.weight.*.safetensors` in the weight map.

        Returns:
            bool: True if the index file contains any distributed tensor, False otherwise.
        """
        for value in self.weight_map.values():
            if value.endswith(".*.bin") or value.endswith(".*.safetensors"):
                return True
        return False

    def get_checkpoint_filenames(self) -> Tuple[List[str], List[str]]:
        """
        Get the checkpoint filenames in the weight map, split into regular
        checkpoint files and distributed tensor (dtensor) checkpoint files.

        Returns:
            Tuple[List[str], List[str]]: regular checkpoint shard filenames and dtensor checkpoint filenames.
        """
        # read the checkpoint file list from the weight map and deduplicate the file names
        checkpoint_files = sorted(set(self.weight_map.values()))

        # get the absolute paths for all checkpoint files
        checkpoint_files = [str(self.root_path.joinpath(f)) for f in checkpoint_files]

        dtensor_list = []
        checkpoint_list = []

        for ckpt_file in checkpoint_files:
            if is_dtensor_checkpoint(ckpt_file):
                dtensor_list.append(ckpt_file)
            else:
                checkpoint_list.append(ckpt_file)

        return checkpoint_list, dtensor_list

    def assert_no_dtensor_checkpoint(self):
        """
        Raise a ValueError if any entry in the weight map points to a distributed tensor checkpoint.
        """
        for val in self.weight_map.values():
            if is_dtensor_checkpoint(val):
                raise ValueError(f"Checkpoint file {val} contains distributed tensor")

    def get_checkpoint_file(self, param_name: str) -> str:
        """
        Get the checkpoint file name for a parameter.

        Args:
            param_name (str): name of the parameter.

        Returns:
            str: checkpoint file name.
        """
        ckpt_path = self.weight_map[param_name]
        return ckpt_path
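A round-trip sketch of the new index file (filenames are made up for illustration):

    from colossalai.checkpoint_io import CheckpointIndexFile

    # Build an index for a hypothetical two-shard checkpoint and write it to disk.
    index = CheckpointIndexFile()
    index.append_meta_data("total_size", 123456)
    index.append_weight_map("linear.weight", "model_00001-of-00002.bin")
    index.append_weight_map("linear.bias", "model_00002-of-00002.bin")
    index.export("model.bin.index.json")

    # Reload it and split the referenced files into regular and dtensor shards.
    index = CheckpointIndexFile.from_file("model.bin.index.json")
    regular_files, dtensor_files = index.get_checkpoint_filenames()
    assert index.get_checkpoint_file("linear.bias") == "model_00002-of-00002.bin"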