Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion colossalai/booster/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def save_optimizer(
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
) -> None:
"""
Save optimizer to checkpoint.
Expand All @@ -374,7 +375,9 @@ def save_optimizer(
names to compose the keys in state_dict. Defaults to None.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
self.checkpoint_io.save_optimizer(
optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
)

def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
"""Save lr scheduler to checkpoint.
Expand Down
12 changes: 10 additions & 2 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool =
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)

def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
Expand Down Expand Up @@ -178,7 +180,13 @@ def load_sharded_model(
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

def save_sharded_optimizer(
self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: GeminiOptimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save sharded optimizer state dict to checkpoint folder.
Expand Down
65 changes: 58 additions & 7 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
get_shard_filename,
load_param_groups_into_optimizer,
load_shard_state_dict,
load_state_dict,
load_states_into_optimizer,
save_param_groups,
save_state_dict,
Expand Down Expand Up @@ -113,7 +114,9 @@ def _hook_context(self):


class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
):
"""Save optimizer to checkpoint but only on master process.

Args:
Expand All @@ -125,9 +128,34 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str,
# the `state_dict` call in LowLevelZeroOptimizer involves communication;
# if only the master rank collected the state_dict and saved it,
# the communication on each rank would not match
state_dict = optimizer.state_dict()
if use_async:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict = optimizer.state_dict(pinned_state_dicts)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
self.async_writers.append(f_writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
    """Load an unsharded optimizer checkpoint into ``optimizer``.

    The on-disk format is chosen from the file name: a path ending in
    ``.safetensors`` is read with ``load_flat``; any other path is read
    with ``load_state_dict``.

    Args:
        optimizer (OptimizerWrapper): optimizer that receives the loaded state dict.
        checkpoint (str): path to the checkpoint file.
    """
    # Test the suffix on the string form so that a path-like object (which has
    # no ``endswith`` method) is handled the same way as a plain string path.
    if str(checkpoint).endswith(".safetensors"):
        from colossalai.utils.safetensors import load_flat

        state_dict = load_flat(checkpoint)
    else:
        state_dict = load_state_dict(checkpoint)
    # Keep the path and the loaded contents in separate names for readability.
    optimizer.load_state_dict(state_dict)

def save_sharded_optimizer(
self,
Expand All @@ -136,6 +164,7 @@ def save_sharded_optimizer(
gather_dtensor: bool = False,
prefix: str = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
Expand All @@ -161,10 +190,16 @@ def save_sharded_optimizer(
# state_dict only provides 'param_groups'
state_dict = optimizer.optim.state_dict()
# state shard would be handled by the low-level zero optimizer
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
if use_async:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)

# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
index_file.append_meta_data("param_groups", param_group_file)

Expand All @@ -184,7 +219,18 @@ def save_sharded_optimizer(

checkpoint_file_path = os.path.join(checkpoint, shard_file)
if self.coordinator.is_master():
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
save_nested(f_writer, shard)
self.async_writers.append(f_writer)
else:
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
Expand Down Expand Up @@ -223,7 +269,12 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: s
checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
if shard_file.endswith(".safetensors"):
from colossalai.utils.safetensors import load_flat

state_dict = load_flat(shard_file)
else:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
# shard state dict
for param_idx, state in state_dict.items():
for k, v in state.items():
Expand Down
9 changes: 7 additions & 2 deletions colossalai/booster/plugin/torch_ddp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str)
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to checkpoint but only on master process.
"""
Expand Down Expand Up @@ -113,13 +115,16 @@ def save_sharded_optimizer(
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save optimizer to sharded checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
super().save_sharded_optimizer(
optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
)

def load_sharded_optimizer(
self,
Expand Down
12 changes: 10 additions & 2 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def save_unsharded_model(
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to checkpoint but only on master process.
"""
Expand Down Expand Up @@ -157,7 +159,13 @@ def load_sharded_model(
model.unwrap().load_state_dict(fsdp_state_dict, strict=False)

def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: Optimizer,
checkpoint: str,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save optimizer to checkpoint but only on master process.
Expand Down
20 changes: 15 additions & 5 deletions colossalai/checkpoint_io/checkpoint_io_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def save_optimizer(
gather_dtensor=True,
prefix: str = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
Expand All @@ -229,11 +230,12 @@ def save_optimizer(
prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
"""

if shard:
self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
self.save_sharded_optimizer(
optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
)
else:
self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)

# ========================================================
# Abstract methods for model loading/saving implementation
Expand Down Expand Up @@ -326,7 +328,13 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):

@abstractmethod
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self,
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save optimizer to sharded checkpoint.
Expand All @@ -340,7 +348,9 @@ def save_sharded_optimizer(
"""

@abstractmethod
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer to unsharded checkpoint.

Expand Down
2 changes: 2 additions & 0 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def save_sharded_optimizer(
gather_dtensor: bool,
prefix: str,
size_per_shard: int,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
Expand Down Expand Up @@ -155,6 +156,7 @@ def save_unsharded_optimizer(
optimizer: Optimizer,
checkpoint: Path,
gather_dtensor: bool,
use_async: bool = False,
):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
Expand Down
5 changes: 4 additions & 1 deletion colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def save_sharded_optimizer(
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
Expand Down Expand Up @@ -725,7 +726,9 @@ def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: boo
# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
):
"""
Save optimizer state dict to a file with given path.

Expand Down
9 changes: 8 additions & 1 deletion colossalai/checkpoint_io/moe_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def save_sharded_optimizer(
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_async: bool = False,
):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
Expand Down Expand Up @@ -729,7 +730,13 @@ def save_unsharded_model(
dist.barrier()

# Copied from colossalai.moe
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(
self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool,
use_async: bool = False,
):
"""
Save optimizer state dict to a file with given path.

Expand Down
8 changes: 5 additions & 3 deletions colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
STATES_NAME = "pytorch_optim.bin"
SAFE_STATE_NAME = "optimizer.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
GROUP_FILE_NAME = "pytorch_optim_group.bin"

# ======================================
Expand Down Expand Up @@ -839,14 +841,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
return weights_name, save_index_file


def get_optimizer_base_filenames(prefix: str = None):
def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
"""
generate base optimizer state filenames
"""
states_name = STATES_NAME
states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
states_name = add_prefix(states_name, prefix)

save_index_file = STATES_INDEX_NAME
save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
save_index_file = add_prefix(save_index_file, prefix)

param_group_file = GROUP_FILE_NAME
Expand Down
6 changes: 4 additions & 2 deletions colossalai/testing/comparison.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, OrderedDict
from typing import Any, List, OrderedDict, Tuple

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -78,7 +78,9 @@ def check_state_dict_equal(
v1 = v1.to(v2.dtype)
assert_close_loose(v1, v2)
else:
assert v1 == v2, f"{v1} not equals to {v2}"
if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
v2 = tuple(v2)
assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"


def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
Expand Down
Loading