Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 47 additions & 28 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,44 +33,40 @@ def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()

def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)

def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
Save sharded model to checkpoint but only on master process.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
As there is communication when getting state dict, this must be called on all processes.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
# as there is communication when get state dict, this must be called on all processes
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
"""
Save optimizer to checkpoint but only on master process.
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
# TODO(ver217): optimizer state dict is sharded
warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
warnings.warn(
'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().load_optimizer(optimizer, checkpoint)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
super().load_unsharded_model(model, checkpoint, strict=strict)

def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save model to checkpoint but only on master process.
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
As there is communication when getting state dict, this must be called on all processes.
The saving process will only be executed by master rank.
"""
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
    """
    Load an unsharded optimizer checkpoint by delegating to the parent checkpoint IO.

    Args:
        optimizer (Optimizer): optimizer whose states are to be restored.
        checkpoint (str): path of the unsharded optimizer checkpoint file.
    """
    # Nothing plugin-specific happens here; the parent implementation does all the work.
    # NOTE(review): the previous docstring stated that each process only loads the states
    # of the parameters it controls -- that filtering is not visible in this method.
    parent_io = super()
    parent_io.load_unsharded_optimizer(optimizer, checkpoint)

def save_sharded_model(self,
model: GeminiDDP,
Expand All @@ -82,6 +78,12 @@ def save_sharded_model(self,
"""
Save sharded model
"""
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return

Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
total_size = 0
Expand Down Expand Up @@ -117,6 +119,23 @@ def load_sharded_model(self,
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
                           size_per_shard: int):
    """
    Save sharded optimizer state dict to checkpoint folder.
    As there is communication when getting state dict, this must be called on all processes.

    Args:
        optimizer (Optimizer): optimizer to be saved.
        checkpoint (Path): directory receiving the shard files; created (with parents) if missing.
        gather_dtensor (bool): forwarded unchanged to the parent implementation.
        prefix (str): forwarded unchanged to the parent implementation.
        size_per_shard (int): forwarded unchanged to the parent implementation.
    """
    # Same guard as save_sharded_model: without it, mkdir below raises FileExistsError
    # when the given path is an existing regular file.
    if os.path.isfile(checkpoint):
        logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
        return

    Path(checkpoint).mkdir(parents=True, exist_ok=True)
    super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)

def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
    """
    Placeholder for loading a sharded optimizer from a checkpoint folder, with index file given.
    Intended behavior: for each process, only loading optimizer states of parameters it controls.

    This is currently a stub: nothing is read from the checkpoint and the optimizer is left
    untouched. A warning is emitted so that the missing restore does not go unnoticed.

    Args:
        optimizer (Optimizer): optimizer that would receive the loaded states (unused for now).
        checkpoint_index_file (Path): index file of the sharded checkpoint (unused for now).
        prefix (str): unused for now.
    """
    # Imported locally so this method does not depend on the module-level import list.
    import warnings

    # TODO(Baizhou): To be implemented.
    warnings.warn('GeminiPlugin does not support loading sharded optimizer checkpoint now. '
                  'The optimizer states are left unchanged.')


class GeminiModel(ModelWrapper):

Expand Down Expand Up @@ -193,7 +212,7 @@ class GeminiPlugin(DPPluginBase):
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
Defaults to 0.0.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
Expand All @@ -219,7 +238,7 @@ def __init__(
min_chunk_size_m: float = 32,
memstats: Optional[MemStats] = None,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
initial_scale: float = 2**16,
Comment thread
Fridge003 marked this conversation as resolved.
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
Expand Down
2 changes: 2 additions & 0 deletions colossalai/checkpoint_io/checkpoint_io_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""

index_file_exists, index_file_path = has_index_file(checkpoint)

if Path(checkpoint).is_dir() and not index_file_exists:
Expand Down Expand Up @@ -186,6 +187,7 @@ def save_optimizer(self,
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""

if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
else:
Expand Down
14 changes: 9 additions & 5 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
unwrap_optimizer,
)

__all__ = ['GeneralCheckpointIO']
Expand Down Expand Up @@ -59,7 +60,7 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre

# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.optim
optimizer = unwrap_optimizer(optimizer)

# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
Expand Down Expand Up @@ -96,6 +97,11 @@ def save_sharded_optimizer(
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
"""

# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = unwrap_optimizer(optimizer)

if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Expand All @@ -121,9 +127,8 @@ def save_sharded_optimizer(
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for param_id in shard.keys():
index_file.append_weight_map(str(param_id), shard_file)

for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

Expand Down Expand Up @@ -177,7 +182,6 @@ def save_sharded_model(self,
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)

checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)

Expand Down
24 changes: 20 additions & 4 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import torch.nn as nn
from torch.optim import Optimizer

from colossalai.interface import OptimizerWrapper
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor.d_tensor import is_distributed_tensor

SAFE_WEIGHTS_NAME = "model.safetensors"
Expand Down Expand Up @@ -88,6 +90,19 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# ======================================
# Helper functions for saving shard file
# ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper):
    """
    Unwrap a wrapped optimizer.
    This method should be used before saving/loading it to/from sharded checkpoints.

    Args:
        optimizer (OptimizerWrapper): wrapper exposing the wrapped optimizer as ``optim``.

    Returns:
        The object stored in ``optimizer.optim``; when that object is itself a
        ColossalaiOptimizer, its own ``optim`` attribute is returned instead.
    """
    # TODO(Baizhou): ColossalaiOptimizer will be replaced with OptimizerWrapper in the future
    unwrapped_optim = optimizer.optim
    if isinstance(unwrapped_optim, ColossalaiOptimizer):
        # A ColossalaiOptimizer is a second wrapping layer: peel it off as well.
        unwrapped_optim = unwrapped_optim.optim
    return unwrapped_optim


def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
Expand All @@ -103,7 +118,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
weight_size = calculate_tensor_size(weight)

# If this weight is going to tip up over the maximal size, we split.
if current_block_size + weight_size > max_shard_size:
if current_block_size + weight_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
Expand Down Expand Up @@ -140,9 +155,10 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
isDTensor = False
for state_tensor in state.values():

# When state_tensor is None (e.g., a SGD optimizer with momentum set to 0),
# When state_tensor is not of Tensor class,
# e.g., a SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error.
if state_tensor is None:
if not isinstance(state_tensor, torch.Tensor):
continue

# If the states are stored as DTensors, mark isDTensor as true.
Expand All @@ -152,7 +168,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->

if not isDTensor:

if current_block_size + state_size > max_shard_size:
if current_block_size + state_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
Expand Down
6 changes: 6 additions & 0 deletions colossalai/interface/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,9 @@ def unscale_grad(self):
"""
raise NotImplementedError(
"The method unscale_grad is only available for optimizers with mixed precision training")

def unwrap(self):
    """
    Return the optimizer held in ``self.optim``, for use when saving/loading checkpoints.
    """
    inner_optim = self.optim
    return inner_optim
64 changes: 48 additions & 16 deletions colossalai/testing/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten


def assert_equal(a: Tensor, b: Tensor):
Expand All @@ -16,7 +17,12 @@ def assert_not_equal(a: Tensor, b: Tensor):


def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
    """
    Assert that two tensors are close, using loose default tolerances.

    Args:
        a (Tensor): first tensor.
        b (Tensor): second tensor.
        rtol (float): relative tolerance passed to ``assert_close``. Defaults to 1e-3.
        atol (float): absolute tolerance passed to ``assert_close``. Defaults to 1e-3.

    Raises:
        AssertionError: raised by ``assert_close`` when the comparison fails; the message
            reports the shapes and dtypes of both tensors.
    """
    # The message is kept on a single line: a backslash continuation inside the string literal
    # would make the message text depend on how the source happens to be indented.
    assert_close(a,
                 b,
                 rtol=rtol,
                 atol=atol,
                 msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, dtype: {a.dtype} vs {b.dtype}")


def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
Expand All @@ -33,25 +39,51 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):


def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
    """
    Assert that two (possibly nested) state dicts hold equal contents.

    Both dicts must have the same number of keys and every key of ``d1`` must exist in ``d2``.
    Values are compared pairwise: dicts and lists are traversed recursively, tensors are compared
    with ``assert_close_loose`` and every other value with ``==``.

    Args:
        d1 (OrderedDict): first state dict.
        d2 (OrderedDict): second state dict.
        ignore_device (bool): when False, both tensors of a pair are moved to CPU before being
            compared; when True they are compared where they are. Defaults to True.

    Raises:
        AssertionError: on the first mismatch found.
    """
    assert len(d1) == len(d2), f"Number of keys unequal: {len(d1)} vs {len(d2)}"
    for k, v1 in d1.items():
        assert k in d2
        _check_state_value_equal(v1, d2[k], ignore_device)


def _check_state_value_equal(v1: Any, v2: Any, ignore_device: bool) -> None:
    """Compare one pair of state-dict values, recursing into nested dicts and lists."""
    if isinstance(v1, dict):
        assert isinstance(v2, dict)
        check_state_dict_equal(v1, v2, ignore_device)
    elif isinstance(v1, list):
        assert isinstance(v2, list)
        # zip() stops at the shorter list, so unequal lengths must be rejected explicitly.
        assert len(v1) == len(v2), f"List length unequal: {len(v1)} vs {len(v2)}"
        for v1_i, v2_i in zip(v1, v2):
            _check_state_value_equal(v1_i, v2_i, ignore_device)
    elif isinstance(v1, torch.Tensor):
        assert isinstance(v2, torch.Tensor)
        # NOTE(review): tensors are moved to CPU only when ignore_device is False, which reads as
        # the opposite of what the flag name suggests -- confirm with callers before changing.
        if not ignore_device:
            v1 = v1.to("cpu")
            v2 = v2.to("cpu")
        assert_close_loose(v1, v2)
    else:
        assert v1 == v2, f"{v1} not equals to {v2}"


def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
    """
    Compare two state dicts leaf by leaf after flattening both with ``tree_flatten``.

    Only the flattened leaves are compared, in order; the tree specs returned by
    ``tree_flatten`` are discarded, so keys and nesting are not checked themselves.
    Tensor leaves go through ``assert_close_loose``, all other leaves through ``==``.

    Args:
        d1 (OrderedDict): first state dict.
        d2 (OrderedDict): second state dict.
        ignore_device (bool): when False, both tensors of a pair are moved to CPU before being
            compared; when True they are compared where they are. Defaults to True.

    Raises:
        AssertionError: if the number of leaves differs or any pair of leaves mismatches.
    """
    leaves_1 = tree_flatten(d1)[0]
    leaves_2 = tree_flatten(d2)[0]
    assert len(leaves_1) == len(leaves_2)
    for leaf_1, leaf_2 in zip(leaves_1, leaves_2):
        # Non-tensor leaves: plain equality, then on to the next pair.
        if not isinstance(leaf_1, torch.Tensor):
            assert leaf_1 == leaf_2, f"{leaf_1} not equals to {leaf_2}"
            continue
        assert isinstance(leaf_2, torch.Tensor)
        if not ignore_device:
            leaf_1, leaf_2 = leaf_1.to("cpu"), leaf_2.to("cpu")
        assert_close_loose(leaf_1, leaf_2)


def assert_hf_output_close(out1: Any,
Expand Down
Loading