6 changes: 3 additions & 3 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -12,7 +12,7 @@
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
from colossalai.checkpoint_io.utils import get_model_base_filenames, get_shard_filename, save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
@@ -76,14 +76,14 @@ def save_sharded_model(self,
model: GeminiDDP,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
Save sharded model
"""
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
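The Gemini path above streams `(shard, size)` pairs from `model.state_dict_shard(...)` and writes each shard while building the index as it goes. The helper below is a minimal, library-free sketch of that pattern; `iter_shards`, the file names, and the plain-JSON index are hypothetical stand-ins for ColossalAI's `get_model_base_filenames`, `get_shard_filename`, and `CheckpointIndexFile`.

```python
import json
import os
from typing import Dict, Iterator, Tuple

import torch


def iter_shards(state_dict: Dict[str, torch.Tensor],
                max_shard_size_mb: int) -> Iterator[Tuple[Dict[str, torch.Tensor], int]]:
    """Hypothetical generator: yield (shard, size_in_bytes) pairs capped at max_shard_size_mb."""
    shard: Dict[str, torch.Tensor] = {}
    size, limit = 0, max_shard_size_mb * 1024 * 1024
    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if shard and size + tensor_size > limit:
            yield shard, size
            shard, size = {}, 0
        shard[name] = tensor
        size += tensor_size
    if shard:
        yield shard, size


def save_streamed(state_dict: Dict[str, torch.Tensor], checkpoint_dir: str,
                  max_shard_size_mb: int = 1024) -> None:
    """Write shards one by one and record a weight map, mirroring the loop in the diff."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    weight_map, total_size = {}, 0
    for idx, (shard, size) in enumerate(iter_shards(state_dict, max_shard_size_mb)):
        shard_file = f"pytorch_model-{idx + 1:05d}.bin"  # assumed naming scheme
        torch.save(shard, os.path.join(checkpoint_dir, shard_file))
        weight_map.update({name: shard_file for name in shard})
        total_size += size
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    with open(os.path.join(checkpoint_dir, "pytorch_model.bin.index.json"), "w") as f:
        json.dump(index, f, indent=2)
```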
16 changes: 13 additions & 3 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -32,7 +32,6 @@ def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
if self.coordinator.is_master():
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)

@@ -54,11 +53,22 @@ def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
Save model to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)

def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
size_per_shard: int):
"""
Save optimizer to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)


class TorchDDPModel(ModelWrapper):
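All of these DDP overrides reduce to the same rank-0-only write guard via `self.coordinator.is_master()`. Below is a minimal sketch of that pattern written with plain `torch.distributed`, assuming the process group has already been initialized (e.g. by `colossalai.launch`); it is illustrative only, not the plugin's own implementation.

```python
import torch
import torch.distributed as dist


def save_on_master(obj, path: str) -> None:
    """Save `obj` from rank 0 only, then synchronize so no rank races past the write."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        torch.save(obj, path)
    if dist.is_initialized():
        dist.barrier()
```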
9 changes: 5 additions & 4 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -1,9 +1,9 @@
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import warnings
from packaging import version
from torch.distributed import ProcessGroup

@@ -69,7 +69,7 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)

def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
size_per_shard: int, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
@@ -87,13 +87,14 @@ def load_sharded_model(self,
"""
raise NotImplementedError("Sharded model checkpoint is not supported yet.")

def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
size_per_shard: int):
"""
Save optimizer to checkpoint but only on master process.
"""
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")

def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
"""
Load optimizer to checkpoint but only on master process.
"""
12 changes: 6 additions & 6 deletions colossalai/checkpoint_io/checkpoint_io_base.py
@@ -103,7 +103,7 @@ def save_model(self,
checkpoint: str,
shard: bool = False,
gather_dtensor: bool = True,
variant: str = None,
prefix: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False):
"""
@@ -128,7 +128,7 @@ def save_model(self,
multiple files. The model shards will be specified by a `model.index.json` file. When shard = True, please ensure
that the checkpoint path is a directory path instead of a file path.
gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
variant (str): If specified, weights are saved in the format pytorch_model.<variant>.bin. Default: None.
prefix (str): If specified, weights are saved in the format pytorch_model.<prefix>.bin. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
"""
@@ -137,11 +137,11 @@ def save_model(self,
model = model.unwrap()

if shard:
self.save_sharded_model(model, checkpoint, gather_dtensor, variant, size_per_shard, use_safetensors)
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
"""
Load optimizer from checkpoint.

@@ -157,7 +157,7 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str):

if index_file_exists:
# the existence of index file means it is a sharded checkpoint
self.load_sharded_optimizer(optimizer, index_file_path)
self.load_sharded_optimizer(optimizer, index_file_path, prefix, size_per_shard)
else:
self.load_unsharded_optimizer(optimizer, checkpoint)

@@ -218,7 +218,7 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
pass

@abstractmethod
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
size_per_shard: int, use_safetensors: bool):
"""
Save model to sharded checkpoint.
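A sharded model save through this interface, using the `prefix` and `size_per_shard` arguments documented above, might look like the sketch below with `GeneralCheckpointIO`. The checkpoint directory is illustrative, and the `load_model` call assumes the same index-file dispatch described for `load_optimizer`.

```python
import torch.nn as nn

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(1024, 1024)
ckpt_io = GeneralCheckpointIO()

# `checkpoint` must be a directory when shard=True; each shard is capped at
# roughly size_per_shard MB and an index JSON maps parameter names to shard files.
ckpt_io.save_model(model,
                   checkpoint="./model_ckpt",
                   shard=True,
                   gather_dtensor=True,
                   prefix=None,
                   size_per_shard=512,
                   use_safetensors=False)

# Loading only needs the directory: the index file is detected and the
# sharded loading path is taken automatically.
ckpt_io.load_model(model, "./model_ckpt")
```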
95 changes: 83 additions & 12 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -11,15 +11,21 @@
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
get_base_filenames,
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
has_index_file,
is_safetensors_available,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
shard_checkpoint,
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
)

__all__ = ['GeneralCheckpointIO']
@@ -44,12 +50,30 @@ def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)

def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int):
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
"""
Load sharded optimizer with the given path to index file.
"""
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)

def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {index_file_path} for an optimizer. \
Lacking param group file under current directory.')
id_map = load_param_groups_into_optimizer(optimizer, param_group_path)

checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
load_states_into_optimizer(optimizer, state_dict, id_map)
del state_dict
gc.collect()

sharded_optimizer_loading_epilogue(optimizer)

def save_sharded_optimizer(
self,
@@ -59,7 +83,54 @@ def save_sharded_optimizer(
prefix: str,
size_per_shard: int,
):
raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return

Path(checkpoint).mkdir(parents=True, exist_ok=True)

# Offload optimizer states. States are broken into shards within max_shard_size.
state_dict = optimizer.state_dict()
sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard)

# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)

# Store the information of param groups to param_group_file.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(state_dict, group_file_path)

# Save shards of optimizer states.
total_size = 0
for idx, shard_pair in enumerate(sharded_state):
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for param_id in shard.keys():
index_file.append_weight_map(str(param_id), shard_file)

checkpoint_file_path = os.path.join(checkpoint, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")

def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
checkpoint = load_state_dict(checkpoint)
optimizer.load_state_dict(checkpoint)

def save_unsharded_optimizer(
self,
@@ -74,7 +145,7 @@ def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
"""
@@ -89,9 +160,9 @@ def save_sharded_model(self,

# shard checkpoint
state_dict = model.state_dict()
state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
state_dict_shard = shard_model_checkpoint(state_dict, max_shard_size=max_shard_size)

weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
@@ -128,7 +199,7 @@ def load_sharded_model(self,

# read checkpoint index file
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
missing_keys = []

for shard_file in checkpoint_files:
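Putting the new optimizer path together, a round trip might look like the sketch below. Signatures are taken from this diff; the directory name is illustrative, and it assumes the index file ends up under the checkpoint directory.

```python
import torch
import torch.nn as nn
from torch.optim import Adam

from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Linear(64, 64)
optimizer = Adam(model.parameters(), lr=1e-3)

# Run one step so the optimizer actually has exp_avg / exp_avg_sq states to save.
model(torch.randn(8, 64)).sum().backward()
optimizer.step()

ckpt_io = GeneralCheckpointIO()

# Writes pytorch_optim.bin.index.json, pytorch_optim_group.bin and
# pytorch_optim-000XX.bin shards (file names per the docstring in the diff).
ckpt_io.save_sharded_optimizer(optimizer, "./optim_ckpt",
                               gather_dtensor=False, prefix=None, size_per_shard=1024)

# load_optimizer detects the index file and dispatches to load_sharded_optimizer,
# which restores param_groups first and then streams the state shards back in.
new_optimizer = Adam(model.parameters(), lr=1e-3)
ckpt_io.load_optimizer(new_optimizer, "./optim_ckpt")
```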
14 changes: 13 additions & 1 deletion colossalai/checkpoint_io/index_file.py
@@ -111,7 +111,7 @@ def contains_dtensor(self):
return True
return False

def get_checkpoint_fileanames(self) -> List[str]:
def get_checkpoint_filenames(self) -> List[str]:
"""
Get the set of checkpoint filenames in the weight map.

@@ -159,6 +159,18 @@ def get_all_param_names(self):
"""
return list(self.weight_map.keys())

def get_param_group_filename(self) -> Union[str, None]:
"""
Get the file name of the param_group file if this is an optimizer checkpoint.
Returns:
str: param_group file name
"""
filename = self.metadata.get("param_groups", None)
if filename:
return str(self.root_path.joinpath(filename))
else:
return None

def write_index_file(self, save_index_file):
"""
Write index file.
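`get_param_group_filename` resolves the `param_groups` entry stored in the index metadata into a path rooted at the index file's directory, and returns `None` for model checkpoints, which have no such entry. Reading an optimizer index back might look like this sketch; the path is hypothetical and assumes a checkpoint written as in the previous example.

```python
from colossalai.checkpoint_io import CheckpointIndexFile

index = CheckpointIndexFile.from_file("./optim_ckpt/pytorch_optim.bin.index.json")

# Same accessors used by GeneralCheckpointIO.load_sharded_optimizer in this diff.
shard_files, _ = index.get_checkpoint_filenames()
param_group_file = index.get_param_group_filename()

print(shard_files)        # per-shard optimizer state files
print(param_group_file)   # path to pytorch_optim_group.bin, or None
```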