32 changes: 25 additions & 7 deletions colossalai/booster/booster.py
@@ -288,7 +288,14 @@ def enable_lora(

return self.plugin.enable_lora(model, pretrained_dir, lora_config, bnb_quantization_config)

def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
def load_model(
self,
model: Union[nn.Module, ModelWrapper],
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> None:
"""Load model from checkpoint.

Args:
@@ -298,8 +305,12 @@ def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, str
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Defaults to True.
low_cpu_mem_mode (bool, optional): whether to load the model in low CPU memory mode. If False, a RAM cache is used to accelerate loading. Defaults to True.
num_threads (int, optional): number of threads to use when loading the model. Only effective when low CPU memory mode is disabled. Defaults to 1.
"""
self.checkpoint_io.load_model(model, checkpoint, strict)
self.checkpoint_io.load_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_model(
self,
@@ -338,18 +349,25 @@ def save_model(
use_async=use_async,
)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
def load_optimizer(
self,
optimizer: Optimizer,
checkpoint: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
) -> None:
"""Load optimizer from checkpoint.

Args:
optimizer (Optimizer): An optimizer boosted by Booster.
checkpoint (str): Path to the checkpoint. It must be a local path.
It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
low_cpu_mem_mode (bool, optional): whether to load the optimizer in low CPU memory mode. If False, a RAM cache is used to accelerate loading. Defaults to True.
num_threads (int, optional): number of threads to use when loading the optimizer. Only effective when low CPU memory mode is disabled. Defaults to 1.
"""
self.checkpoint_io.load_optimizer(optimizer, checkpoint)
self.checkpoint_io.load_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_optimizer(
self,
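For reference, a minimal usage sketch of the new flags (hypothetical checkpoint paths; assumes colossalai has been launched and that `booster`, `model`, and `optimizer` came from an earlier booster.boost(...) call):

# Sketch only: disabling low_cpu_mem_mode opts into the pinned-RAM cache,
# and num_threads sets how many worker threads fill it.
booster.load_model(model, "./ckpt/model", low_cpu_mem_mode=False, num_threads=4)
booster.load_optimizer(optimizer, "./ckpt/optimizer", low_cpu_mem_mode=False, num_threads=4)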
53 changes: 43 additions & 10 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -1,4 +1,3 @@
import gc
import os
import random
from pathlib import Path
@@ -97,13 +96,22 @@ def save_unsharded_model(
else:
save_state_dict(state_dict, checkpoint, use_safetensors)

def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: GeminiDDP,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)
super().load_unsharded_model(
model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -131,13 +139,17 @@ def save_unsharded_optimizer(
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load an unsharded optimizer from a checkpoint file.
Each process loads only the optimizer states of the parameters it controls.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
super().load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_sharded_model(
self,
@@ -206,13 +218,27 @@ def save_sharded_model(
)

def load_sharded_model(
self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
self,
model: GeminiDDP,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load a sharded model from multiple checkpoint files.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
return super().load_sharded_model(
model,
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module=False,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)

def save_sharded_optimizer(
self,
@@ -289,7 +315,14 @@ def save_sharded_optimizer(
ranks=[0],
)

def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
def load_sharded_optimizer(
self,
optimizer: GeminiOptimizer,
checkpoint_index_file: Path,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load a sharded optimizer from a checkpoint folder, given its index file.
Each process loads only the optimizer states of the parameters it controls.
@@ -322,9 +355,9 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi
state_dict_shard = load_flat(shard_file)
else:
state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if not low_cpu_mem_mode:
state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads)
optimizer.load_param_states(state_dict_shard)
del state_dict_shard
gc.collect()

optimizer.optimizer_loading_epilogue()

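Aside: the hunks above call create_pinned_state_dict from colossalai.checkpoint_io.utils, whose body is outside this diff. As a rough mental model only (an assumption, not the library's actual implementation), a helper that copies a state dict into page-locked host memory with a thread pool might look like:

# Hypothetical sketch of a pinned-copy helper; the real
# create_pinned_state_dict(state_dict, empty=False, num_threads=...) may
# differ (its empty flag suggests it can also allocate unfilled pinned buffers).
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict

import torch

def pin_state_dict_sketch(state_dict: Dict[str, Any], num_threads: int = 1) -> Dict[str, Any]:
    def pin(item):
        key, value = item
        if isinstance(value, torch.Tensor):
            # Tensor.pin_memory() returns a copy in page-locked RAM, which makes
            # later host-to-device transfers faster and async-capable.
            return key, value.pin_memory()
        return key, value

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        return dict(pool.map(pin, state_dict.items()))

Pinned pages cannot be swapped out, which is why the default low_cpu_mem_mode=True skips this path on memory-constrained hosts.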
48 changes: 42 additions & 6 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -20,6 +20,7 @@
from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
create_pinned_state_dict,
get_optimizer_base_filenames,
get_shard_filename,
load_param_groups_into_optimizer,
@@ -145,14 +146,18 @@ def save_unsharded_optimizer(
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
use_async = checkpoint.endswith(".safetensors")
if use_async:
from colossalai.utils.safetensors import load_flat

checkpoint = load_flat(checkpoint)
else:
checkpoint = load_state_dict(checkpoint)
if not low_cpu_mem_mode:
checkpoint = create_pinned_state_dict(checkpoint, empty=False, num_threads=num_threads)
optimizer.load_state_dict(checkpoint)

def save_sharded_optimizer(
@@ -239,7 +244,14 @@ def save_sharded_optimizer(
ranks=[0],
)

def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
def load_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
index_file_path: str,
prefix: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""Load sharded optimizer with the given path to index file.

Args:
@@ -283,14 +295,28 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self.coordinator.world_size)
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
if low_cpu_mem_mode:
state_dict[param_idx][k] = state_dict[param_idx][k].clone()

if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)

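The pad-and-split step above distributes each flat optimizer state evenly across ZeRO ranks. A standalone illustration with invented sizes (the padding formula is reconstructed here for the example; the diff computes padding_size outside the visible hunk):

# Illustration only: world_size=4, a 10-element flat state, viewed from rank 1.
import torch
import torch.nn.functional as F

world_size, rank = 4, 1
v = torch.arange(10, dtype=torch.float32)

# Pad so numel divides evenly by world_size: 10 -> 12 elements.
padding_size = (world_size - v.numel() % world_size) % world_size
if padding_size > 0:
    v = F.pad(v, [0, padding_size])

# Each rank keeps only its contiguous slice (12 // 4 = 3 elements each).
v_list = v.split(v.numel() // world_size)
local_shard = v_list[rank].detach().clone()  # tensor([3., 4., 5.])

The clone() matters in low_cpu_mem_mode: it drops the reference to the full padded tensor so only the local slice stays resident. When the pinned cache is used instead, create_pinned_state_dict makes the copy, so the extra clone is skipped.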
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_unsharded_model(model, checkpoint, strict)
super().load_unsharded_model(
model, checkpoint, strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)
model.update_master_params()

def load_sharded_model(
@@ -300,10 +326,20 @@ def load_sharded_model(
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
model._force_wait_all_gather()
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
super().load_sharded_model(
model,
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)
model.update_master_params()

def save_unsharded_model(
39 changes: 33 additions & 6 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -26,12 +26,21 @@ def __init__(self) -> None:
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()

def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
def load_unsharded_model(
self,
model: ModelWrapper,
checkpoint: str,
strict: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
super().load_unsharded_model(
model.unwrap(), checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
@@ -45,12 +54,16 @@ def save_unsharded_model(
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
"""
Load optimizer from checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
super().load_unsharded_optimizer(
optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
@@ -101,12 +114,22 @@ def load_sharded_model(
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from sharded checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
super().load_sharded_model(
model.unwrap(),
checkpoint_index_file,
strict,
use_safetensors,
load_sub_module,
low_cpu_mem_mode=low_cpu_mem_mode,
num_threads=num_threads,
)

def save_sharded_optimizer(
self,
@@ -131,12 +154,16 @@ def load_sharded_optimizer(
optimizer: Optimizer,
index_file_path: str,
prefix: Optional[str] = None,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from sharded checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
super().load_sharded_optimizer(
optimizer.unwrap(), index_file_path, prefix, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads
)

def save_lora_as_pretrained(
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
19 changes: 16 additions & 3 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -43,13 +43,17 @@ def __init__(self) -> None:
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()

def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
def load_unsharded_model(
self, model: ModelWrapper, checkpoint: str, strict: bool, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
model = model.unwrap()
checkpoint = utils.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
def load_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: Path, low_cpu_mem_mode: bool = True, num_threads: int = 1
):
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
if checkpoint.endswith(".safetensors"):
checkpoint = load_flat(checkpoint, seperator=".")
@@ -232,6 +236,8 @@ def load_sharded_model(
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from checkpoint, on the master process only.
@@ -354,7 +360,14 @@ def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
f"index located at {save_index_file}."
)

def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
size_per_shard: int,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load optimizer from checkpoint, on the master process only.
"""