12 changes: 11 additions & 1 deletion colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,4 +1,4 @@
from typing import Callable, Iterator, List, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -50,6 +50,16 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)

def save_sharded_model(self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = False,
variant: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False):
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
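
The new `save_sharded_model` override applies the same master-only gating as the `save_lr_scheduler` method above it: under DDP every rank holds an identical replica, so only the coordinating rank needs to touch the filesystem. A minimal standalone sketch of that pattern (the helper name and barrier placement are illustrative, not ColossalAI's API):

```python
import torch
import torch.distributed as dist
import torch.nn as nn


def save_on_master(model: nn.Module, checkpoint_path: str) -> None:
    # Every rank holds the same replica under DDP, so a single writer
    # avoids redundant I/O and concurrent writes to the same file.
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(model.state_dict(), checkpoint_path)
    if dist.is_initialized():
        # Keep other ranks from reading a partially written checkpoint.
        dist.barrier()
```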


class TorchDDPModel(ModelWrapper):

6 changes: 2 additions & 4 deletions colossalai/checkpoint_io/checkpoint_io_base.py
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union
from typing import Optional
from typing import Optional, Union

import torch
import torch.nn as nn
@@ -84,9 +83,8 @@ def load_model(self,
# containing no distributed tensors, dtensor -> full tensor conversion
# should be done offline via our CLI
# the existence of index file means it is a sharded checkpoint
ckpt_path = Path(checkpoint)
index_file_exists, index_file_path = has_index_file(checkpoint)

# return the origin model instead of the unwrapped model
origin_model = model

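The removed `ckpt_path = Path(checkpoint)` was dead code: `has_index_file` accepts the raw path and does its own `Path` handling. For orientation, a sketch of the index-file heuristic it implements (illustrative names and simplified error handling; the real helper lives in `colossalai/checkpoint_io/utils.py`):

```python
from pathlib import Path
from typing import Optional, Tuple


def detect_index_file(checkpoint: str) -> Tuple[bool, Optional[Path]]:
    # A sharded checkpoint is a directory holding shard files plus a
    # single "*.index.json" that maps parameter names to shard files.
    path = Path(checkpoint)
    if path.is_file():
        return False, None
    index_files = list(path.glob("*.index.json"))
    if len(index_files) == 1:
        return True, index_files[0]
    return False, None
```
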
62 changes: 36 additions & 26 deletions colossalai/checkpoint_io/utils.py
@@ -1,10 +1,12 @@
# coding=utf-8
import re
from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple

import torch
import torch.nn as nn
from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator

from colossalai.tensor.d_tensor.d_tensor import DTensor
import re

SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@@ -15,6 +17,7 @@
# General helper functions
# ======================================


def calculate_tensor_size(tensor: torch.Tensor) -> float:
"""
Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
@@ -28,6 +31,7 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float:
"""
return tensor.numel() * tensor.element_size() / 1024 / 1024
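
A quick sanity check of the formula: 1024 × 1024 float32 elements occupy exactly 4 MB.

```python
import torch

# 1,048,576 elements * 4 bytes / 1024 / 1024 = 4.0
t = torch.empty(1024, 1024, dtype=torch.float32)
assert calculate_tensor_size(t) == 4.0
```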


def is_safetensors_available() -> bool:
"""
Check whether safetensors is available.
@@ -78,7 +82,6 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# Helper functions for saving shard file
# ======================================
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:

"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@@ -100,35 +103,39 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
current_block_size = 0
current_block[key] = weight
current_block_size += weight_size

if ret_block is not None:
yield ret_block, ret_block_size

yield current_block, current_block_size
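
Only the tail of `shard_checkpoint` is visible in this hunk. For context, a self-contained sketch of the greedy packing it performs; this simplified variant yields each block as soon as the next tensor would overflow the cap, and skips the `ret_block` buffering and DTensor handling of the real function:

```python
from collections import OrderedDict
from typing import Iterator, Mapping, Tuple

import torch


def shard_state_dict(state_dict: Mapping[str, torch.Tensor],
                     max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, float]]:
    current_block: OrderedDict = OrderedDict()
    current_block_size = 0.0
    for key, weight in state_dict.items():
        weight_size = weight.numel() * weight.element_size() / 1024 / 1024
        # Close the current shard once this tensor would exceed the cap.
        if current_block and current_block_size + weight_size > max_shard_size:
            yield current_block, current_block_size
            current_block, current_block_size = OrderedDict(), 0.0
        current_block[key] = weight
        current_block_size += weight_size
    if current_block:
        yield current_block, current_block_size
```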


def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
"""
load shard state dict into model
"""
if use_safetensors and checkpoint_file.suffix != ".safetensors":
raise Exception("loading with `safetensors` requires a checkpoint file ending in `.safetensors`")
if use_safetensors:
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import safe_open
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file)
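
The `metadata["format"] != "pt"` guard reads the metadata dict embedded in the safetensors archive, which the writer sets at save time. A minimal example of producing a compatible file (assumes the `safetensors` package is installed):

```python
import torch
from safetensors.torch import save_file

# The {"format": "pt"} entry is what the guard above inspects before
# trusting the archive as PyTorch weights.
tensors = {"weight": torch.zeros(2, 2)}
save_file(tensors, "model.safetensors", metadata={"format": "pt"})
```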

def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):


def load_state_dict_into_model(model: nn.Module,
state_dict: torch.Tensor,
missing_keys: List,
strict: bool = False,
load_sub_module: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.

Args:
state_dict (dict): a dict containing parameters and
@@ -166,11 +173,12 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):

if strict:
if len(unexpected_keys) > 0:
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys))
error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in unexpected_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
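
The function walks the module tree itself rather than calling `model.load_state_dict` once, which is what makes `load_sub_module` toggleable. A stripped-down sketch of that recursion; it leans on PyTorch's private `_load_from_state_dict` hook, so treat it as illustrative rather than the module's exact code:

```python
import torch.nn as nn


def load_recursive(module: nn.Module, state_dict, prefix: str = "") -> None:
    # _load_from_state_dict copies only the parameters and buffers owned
    # directly by `module`; children are visited by the recursion below.
    module._load_from_state_dict(state_dict, prefix, {}, True, [], [], [])
    for name, child in module._modules.items():
        if child is not None:
            load_recursive(child, state_dict, prefix + name + ".")
```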



# ======================================
# Helper functions for saving state dict
# ======================================
@@ -350,6 +358,8 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
return True, index_files[0]
else:
return False, None
else:
raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.')
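
The new `else` branch makes an invalid path fail loudly instead of falling through and implicitly returning `None`, which would only surface later as an unpacking error at the call site. Expected behavior after the change (the index filename below is illustrative):

```python
has_index_file("ckpt_dir")       # (True, Path("ckpt_dir/pytorch_model.bin.index.json"))
has_index_file("model.bin")      # (False, None) -- a plain file is a whole checkpoint
has_index_file("missing/path")   # now raises RuntimeError instead of returning None
```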


def load_state_dict(checkpoint_file_path: Path):
@@ -380,7 +390,6 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
return torch.load(checkpoint_file_path)



def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
@@ -392,22 +401,23 @@
return weights_name
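
Only the fall-through `return` of `add_variant` is visible here. Assuming it follows the Hugging Face convention of splicing the variant tag in before the file extension, its behavior would be:

```python
add_variant("pytorch_model.bin", None)    # "pytorch_model.bin" (unchanged)
add_variant("pytorch_model.bin", "fp16")  # "pytorch_model.fp16.bin" (assumed HF-style splice)
```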


def get_base_filenames(variant: str = None, use_safetensors: bool = False):
"""
generate base weight filenames
"""
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)

save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = add_variant(save_index_file, variant)

return weights_name, save_index_file

def get_shard_filename(weights_name: str, idx: int):
"""
get shard file name
"""
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
return shard_file
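
Shard indices are 1-based and zero-padded to five digits, matching the Hugging Face shard naming scheme:

```python
get_shard_filename("pytorch_model.bin", 0)     # "pytorch_model-00001.bin"
get_shard_filename("model.safetensors", 11)    # "model-00012.safetensors"
```
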
@@ -11,9 +11,9 @@
from tests.kit.model_zoo import model_zoo

# These models are not compatible with AMP
_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn`']
_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn']
# These models have no parameters
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch']
_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch']
# These models will get stuck
_STUCK_MODELS = [
'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert',
@@ -67,6 +67,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
skipped_models.append(name)
continue
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)

torch.cuda.empty_cache()

if err is None:
@@ -91,7 +92,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):

@rerun_if_address_is_in_use()
def test_low_level_zero_plugin(early_stop: bool = True):
spawn(run_dist, 2, early_stop=early_stop)
spawn(run_dist, 4, early_stop=early_stop)


if __name__ == '__main__':
50 changes: 33 additions & 17 deletions tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
@@ -1,19 +1,20 @@
import tempfile

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPCheckpointIO
from colossalai.interface import OptimizerWrapper
from colossalai.testing import check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn


def check_torch_ddp_checkpointIO():
@parameterize('shard', [True, False])
def check_torch_ddp_checkpointIO(shard: bool):
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = resnet18()
@@ -34,23 +35,38 @@ def check_torch_ddp_checkpointIO():
optimizer.step()
scheduler.step()

optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile()
ckpt_io = TorchDDPCheckpointIO()
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
ckpt_io.save_lr_scheduler(scheduler, lr_scheduler_ckpt_tempfile.name)
with tempfile.TemporaryDirectory() as tempdir:
obj = [tempdir]
dist.broadcast_object_list(obj, src=0)
tempdir = obj[0] # use the same directory on all ranks

new_model = resnet18()
new_optimizer = SGD((new_model.parameters()), lr=0.001)
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
_, new_optimizer, _, _, new_scheduler = booster.boost(new_model, new_optimizer, lr_scheduler=new_scheduler)
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
booster.save_model(model, model_ckpt_path, shard=shard)
if not shard:
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
booster.save_optimizer(optimizer, optimizer_ckpt_path)
booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
dist.barrier()

if ckpt_io.coordinator.is_master():
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
new_model = resnet18()
new_optimizer = SGD((new_model.parameters()), lr=0.001)
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1)
new_model, new_optimizer, _, _, new_scheduler = booster.boost(new_model,
new_optimizer,
lr_scheduler=new_scheduler)

ckpt_io.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_tempfile.name)
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)

if not shard:
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)

dist.barrier()
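
A note on the `broadcast_object_list` call near the top of the test: every rank enters `tempfile.TemporaryDirectory()` and creates its own directory, but broadcasting rank 0's name makes all processes save into and load from the same location. The general pattern, as a sketch:

```python
import torch.distributed as dist


def share_from_rank0(value):
    # broadcast_object_list pickles the payload on the source rank and
    # unpickles it on every other rank, replacing the list slot in place.
    obj = [value]    # values passed on non-source ranks are ignored
    dist.broadcast_object_list(obj, src=0)
    return obj[0]
```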


def run_dist(rank, world_size, port):