2 changes: 2 additions & 0 deletions colossalai/booster/booster.py
@@ -310,6 +310,7 @@ def save_model(
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
) -> None:
"""Save model to checkpoint.

@@ -333,6 +334,7 @@ def save_model(
prefix=prefix,
size_per_shard=size_per_shard,
use_safetensors=use_safetensors,
use_async=use_async,
)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
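A minimal usage sketch of the new flag at the Booster level, hedged: it assumes the booster exposes its plugin's CheckpointIO as `booster.checkpoint_io`, that `tensornvme` is installed for the async writers, and that `model`, `optimizer`, and the checkpoint path are placeholders.

# Sketch only: start an async sharded save, overlap other work with the
# background writes, then wait for the device-to-host copies before the next update.
booster.save_model(model, "ckpt_dir", shard=True, size_per_shard=1024, use_safetensors=True, use_async=True)
# ... other host-side work can run while pinned-memory copies and file writes proceed ...
booster.checkpoint_io.synchronize()  # must happen before the weights change again
optimizer.step()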
9 changes: 6 additions & 3 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -259,10 +259,12 @@ def load_sharded_model(
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()

def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)

def save_sharded_model(
self,
@@ -272,11 +274,12 @@ def save_sharded_model(
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
return super().save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
)

def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
17 changes: 14 additions & 3 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -33,13 +33,17 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)

def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
super().save_unsharded_model(
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
@@ -71,14 +75,21 @@ def save_sharded_model(
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_sharded_model(
model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
model.unwrap(),
checkpoint_path,
gather_dtensor,
prefix,
max_shard_size,
use_safetensors,
use_async=use_async,
)

def load_sharded_model(
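A hedged sketch of the same flag through the DDP plugin, where only the master process reaches the parent save path, so only rank 0 opens async writers (the distributed launch details and the training objects are placeholders):

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)
# Non-master ranks return before the parent call, so they queue no writers.
booster.save_model(model, "model.safetensors", use_safetensors=True, use_async=True)
booster.checkpoint_io.synchronize()  # no-op on ranks with no pending writers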
48 changes: 44 additions & 4 deletions colossalai/checkpoint_io/checkpoint_io_base.py
@@ -1,13 +1,14 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger

from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, has_index_file

@@ -58,9 +59,34 @@ class CheckpointIO(ABC):
>>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
"""

N_WRITE_ENTRIES: int = 32

# ======================================
# Public methods
# ======================================
def __init__(self):
super().__init__()
self.pinned_state_dicts: Dict[int, dict] = {}
self.async_writers = []

def _sync_io(self):
for writer in self.async_writers:
writer.synchronize()
writer.fp.close()
self.async_writers.clear()

def _sync_d2h(self):
for writer in self.async_writers:
writer.sync_before_step()

def synchronize(self):
"""This method must be called before updating the model weights."""
self._sync_d2h()

def __del__(self):
self._sync_d2h()
self._sync_io()

def load_model(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True
) -> Union[nn.Module, ModelWrapper]:
@@ -111,6 +137,7 @@ def save_model(
prefix: str = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save model to checkpoint.
@@ -138,11 +165,21 @@ def save_model(
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True.
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
"""
self._sync_io()
if use_async and not use_safetensors:
logger = get_dist_logger()
logger.warning(
"Async save is only supported when use_safetensors is set to True. "
"Setting use_safetensors to True for async save."
)
use_safetensors = True

if shard:
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
self.save_sharded_model(
model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors, use_async=use_async
)
else:
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = None, size_per_shard: int = 1024):
"""
@@ -234,6 +271,7 @@ def save_sharded_model(
prefix: Optional[str],
size_per_shard: int,
use_safetensors: bool,
use_async: bool = False,
):
"""
Save model to sharded checkpoint.
@@ -248,7 +286,9 @@
"""

@abstractmethod
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model to unsharded checkpoint.

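To make the two-phase synchronization concrete, here is a hedged sketch of the intended ordering in a training loop, using only the methods added above. GeneralCheckpointIO stands in for any concrete CheckpointIO; `model`, `optimizer`, `dataloader`, and the loss are placeholders.

from colossalai.checkpoint_io import GeneralCheckpointIO

ckpt_io = GeneralCheckpointIO()
save_interval = 100
for step, batch in enumerate(dataloader):
    loss = model(batch).sum()
    loss.backward()
    if step % save_interval == 0:
        # Starts D2H copies into pinned buffers and background file writes;
        # writers left over from the previous save are flushed first via _sync_io().
        ckpt_io.save_model(model, f"ckpt_step{step}.safetensors", use_safetensors=True, use_async=True)
    ckpt_io.synchronize()  # wait for the D2H copies (not the file writes) before the weights change
    optimizer.step()
    optimizer.zero_grad()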
57 changes: 44 additions & 13 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -8,9 +8,13 @@
import torch.nn as nn
from torch.optim import Optimizer

from colossalai.utils.safetensors import move_and_save

from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
async_save_state_dict_shards,
create_pinned_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
@@ -40,15 +44,27 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
checkpoint = load_state_dict(checkpoint)
model.load_state_dict(checkpoint, strict=strict)

def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
state_dict = model.state_dict()

# TODO(FrankLeeeee): add support for gather_dtensor
if gather_dtensor:
pass

# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
else:
# save the checkpoint
save_state_dict(state_dict, checkpoint, use_safetensors)

def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
"""
@@ -151,6 +167,7 @@ def save_sharded_model(
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
implement this method as it can be supported by Huggingface model,
@@ -168,16 +185,30 @@
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint_path)

# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
)
if use_async:
pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
total_size, new_pinned_state_dict, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
pinned_state_dict=pinned_state_dict,
n_write_entries=self.N_WRITE_ENTRIES,
)
self.pinned_state_dicts[id(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
# Save shards of optimizer states.
# In general cases, is_master is set to True to get the right behavior.
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
)

index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
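Because the pinned buffers are cached per model id, repeated saves of the same model reuse the same page-locked CPU memory. A hedged sketch of that reuse (directory names are illustrative; earlier writers are flushed by the _sync_io() call at the top of save_model):

from colossalai.checkpoint_io import GeneralCheckpointIO

io = GeneralCheckpointIO()
io.save_model(model, "ckpt_epoch0", shard=True, use_safetensors=True, use_async=True)  # allocates pinned buffers for this model
io.save_model(model, "ckpt_epoch1", shard=True, use_safetensors=True, use_async=True)  # waits on epoch0's writers, then reuses the buffers
io.synchronize()  # wait for the last D2H copies before the weights are updated again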
75 changes: 74 additions & 1 deletion colossalai/checkpoint_io/utils.py
@@ -5,7 +5,7 @@
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple

import torch
import torch.nn as nn
@@ -19,6 +19,7 @@
to_global,
to_global_for_customized_distributed_tensor,
)
from colossalai.utils.safetensors import move_and_save

SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@@ -263,6 +264,71 @@ def save_state_dict_shards(
return total_size


def async_save_state_dict_shards(
sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
pinned_state_dict: Optional[Dict[str, torch.Tensor]],
n_write_entries: int,
use_pp_format: bool = False,
) -> Tuple[int, Dict[str, torch.Tensor], list]:
"""
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is main process.
use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.

Returns:
int: the total size of shards
"""
from tensornvme.async_file_io import AsyncFileWriter

total_size = 0
shard_filenames = []
if pinned_state_dict is None:
returned_state_dict = {}
else:
returned_state_dict = pinned_state_dict
writers = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
# Non-master ranks still consume the shard generator (it may perform collective gathers) but do not write.
if not is_master:
del shard
continue
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)

writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
writers.append(writer)

if pinned_state_dict is not None:
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
else:
sub_pinned_state_dict = create_pinned_state_dict(shard)
returned_state_dict.update(sub_pinned_state_dict)

# Only save on master rank.
move_and_save(writer, shard, sub_pinned_state_dict)
shard_filenames.append(shard_file)
del shard

# Clean folder: delete unneeded files.
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)

return total_size, returned_state_dict, writers


def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
@@ -799,3 +865,10 @@ def get_shard_filename(weights_name: str, idx: int):
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
return shard_file


def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
pin_mem = dict()
for name, tensor in state_dict.items():
pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
return pin_mem
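The buffers created by create_pinned_state_dict are page-locked, which is what lets the device-to-host copy run asynchronously while the Python thread queues the file write. A minimal, self-contained sketch of that pattern in plain PyTorch; it mirrors the idea, not the internals of colossalai.utils.safetensors.move_and_save:

import torch

src = torch.randn(1024, 1024, device="cuda")
pinned = torch.empty_like(src, device="cpu", pin_memory=True)  # same allocation as create_pinned_state_dict does per tensor
pinned.copy_(src, non_blocking=True)  # D2H copy overlaps with host-side work
# ... hand `pinned` to a background writer, keep doing other work ...
torch.cuda.synchronize()  # the analogue of CheckpointIO.synchronize(): wait before mutating src or reading pinned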