Merged
28 commits
5c1c06c  gemini plugin add shard checkpoint save/load (Apr 19, 2023)
dd2579d  gemini plugin add shard checkpoint save/load (Apr 19, 2023)
351d7eb  gemini plugin add shard checkpoint save/load (Apr 19, 2023)
a43fae8  gemini plugin add shard checkpoint save/load (Apr 19, 2023)
d0ab0a0  gemini plugin add shard checkpoint save/load (Apr 19, 2023)
a636b46  gemini plugin add shard checkpoint save/load (Apr 19, 2023)
4f9f603  gemini plugin add shard checkpoint save/load (Apr 19, 2023)
9d67750  gemini plugin add shard checkpoint save/load (Apr 19, 2023)
53bc248  gemini plugin add shard checkpoint save/load (flybird11111, Apr 19, 2023)
777ac89  gemini plugin add shard checkpoint save/load (flybird11111, Apr 19, 2023)
327c9a3  gemini plugin add shard checkpoint save/load (flybird11111, Apr 20, 2023)
a75cc86  gemini plugin add shard checkpoint save/load (flybird11111, Apr 20, 2023)
83c5740  gemini plugin add shard checkpoint save/load (flybird11111, Apr 20, 2023)
f90afe4  gemini plugin add shard checkpoint save/load (flybird11111, Apr 20, 2023)
dd7d03f  gemini plugin support shard checkpoint (flybird11111, Apr 24, 2023)
a310915  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, Apr 24, 2023)
5f863ef  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, Apr 24, 2023)
29184cf  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, Apr 24, 2023)
f3f1dca  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, Apr 24, 2023)
617756d  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, Apr 24, 2023)
5cecee6  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, Apr 26, 2023)
4d59978  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, Apr 26, 2023)
7aaa096  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, May 4, 2023)
1cabb7b  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, May 4, 2023)
3cdbda5  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, May 4, 2023)
9be4e9c  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, May 5, 2023)
fca924c  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, May 5, 2023)
6d8270e  [API Refactoring]gemini plugin support shard checkpoint (flybird11111, May 5, 2023)
44 changes: 44 additions & 0 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -1,6 +1,9 @@
import random
import warnings
from typing import Callable, List, Optional, Tuple, Union
from pathlib import Path
import os
import logging

import numpy as np
import torch
@@ -20,6 +23,13 @@
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.memory_tracer import MemStats

from colossalai.checkpoint_io.utils import (
get_base_filenames,
get_shard_filename
)

from colossalai.checkpoint_io import CheckpointIndexFile

from .plugin_base import Plugin

__all__ = ['GeminiPlugin']
@@ -62,6 +72,40 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)

def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
"""
Save sharded model
"""
state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
if not self.coordinator.is_master():
continue
shard = shard_pair[0]
shard_file = get_shard_filename(weights_name, idx)
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)

checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)

index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(
f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)


def load_sharded_model(self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
"""
load shard model, load model from multiple files
"""
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

class GeminiModel(ModelWrapper):

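A quick orientation sketch (not part of this diff): reading back the index file that the new save_sharded_model writes. The checkpoint directory and the pytorch_model.bin.index.json filename are assumptions based on the default naming helpers used above.

import json
import os

ckpt_dir = "gemini_ckpt"  # hypothetical directory passed as checkpoint_path
index_path = os.path.join(ckpt_dir, "pytorch_model.bin.index.json")  # assumed default index name

with open(index_path, "r", encoding="utf-8") as f:
    index = json.load(f)

# "total_size" was accumulated from the per-shard sizes via append_meta_data
print(index["metadata"]["total_size"])
# weight_map maps every parameter key to the shard file that stores it
print(list(index["weight_map"].items())[:3])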
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/checkpoint_io_base.py
@@ -86,7 +86,7 @@ def load_model(self,
# the existence of index file means it is a sharded checkpoint
ckpt_path = Path(checkpoint)
index_file_exists, index_file_path = has_index_file(checkpoint)

# return the origin model instead of the unwrapped model
origin_model = model

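For context (not part of this diff), the branch above hinges on has_index_file, which the sharded path in this PR relies on. A tiny sketch with a hypothetical path, assuming the helper is exported from colossalai.checkpoint_io.utils:

from colossalai.checkpoint_io.utils import has_index_file

index_file_exists, index_file_path = has_index_file("./ckpt_dir")
# True plus the index path when "./ckpt_dir" holds a sharded checkpoint,
# in which case load_model dispatches to the sharded loading branch.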
63 changes: 34 additions & 29 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -1,12 +1,12 @@
from pathlib import Path
from functools import reduce

import torch.nn as nn
from torch.optim import Optimizer
import logging
import os
import json
import gc
from typing import Optional
from typing import Optional, Iterator, OrderedDict, Tuple

from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
@@ -18,10 +18,9 @@
shard_checkpoint,
load_shard_state_dict,
load_state_dict_into_model,
add_variant
get_shard_filename,
get_base_filenames
)
from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME


__all__ = ['GeneralCheckpointIO']

@@ -85,30 +84,32 @@ def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dten

# shard checkpoint
state_dict = model.state_dict()
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)

# Save the model
for shard_file, shard in shards.items():
state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)

weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
total_size = 0
index_file = CheckpointIndexFile(checkpoint_path)
for idx, shard_pair in enumerate(state_dict_shard):
shard = shard_pair[0]
shard_file = get_shard_filename(weights_name, idx)
total_size = total_size + shard_pair[1]
for key in shard.keys():
index_file.append_weight_map(key, shard_file)

checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
save_state_dict(shard, checkpoint_file_path, use_safetensors)

# save index file
save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME

save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)

index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)


def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
use_safetensors: bool = False, load_sub_module: bool = True):
"""
load shard model, load model from multiple files
"""
@@ -122,17 +123,21 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri
# read checkpoint index file
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
missing_keys = ckpt_index_file.get_all_param_names()
missing_keys = []

for shard_file in checkpoint_files:
state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
load_state_dict_into_model(model, state_dict, missing_keys, strict)
load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
del state_dict
gc.collect()

if strict and len(missing_keys) > 0:
error_msgs = 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
if strict:
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
if len(remain_keys) > 0:
error_msgs = 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys))
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
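A minimal end-to-end sketch (not part of this diff) of the reworked sharded save/load in GeneralCheckpointIO; the toy model, directory, and index filename are illustrative assumptions.

import os
import torch.nn as nn
from colossalai.checkpoint_io import GeneralCheckpointIO

model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 10))  # toy model
ckpt_io = GeneralCheckpointIO()

os.makedirs("./ckpt_dir", exist_ok=True)
# Shards now come from the shard_checkpoint generator and are written one at a
# time, so only a single shard is held in memory during saving.
ckpt_io.save_sharded_model(model, "./ckpt_dir", max_shard_size=1024)

# Loading walks every file listed in the index; with the default
# load_sub_module=True the state dict is loaded module by module.
ckpt_io.load_sharded_model(model, "./ckpt_dir/pytorch_model.bin.index.json", strict=True)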



16 changes: 14 additions & 2 deletions colossalai/checkpoint_io/index_file.py
@@ -1,6 +1,8 @@
import json
from pathlib import Path
from typing import Any, List, Union
import os
import json

from .utils import is_dtensor_checkpoint

@@ -18,8 +20,8 @@ class CheckpointIndexFile:
>>> index.export('new_index.json')
"""

def __init__(self) -> None:
self.root_path = None
def __init__(self, root_path=None) -> None:
self.root_path = root_path
self.metadata: dict = dict()
self.weight_map: dict = dict()

@@ -154,3 +156,13 @@ def get_all_param_names(self):
Get all the weight keys.
"""
return list(self.weight_map.keys())

def write_index_file(self, save_index_file):
"""
Write index file.
"""
save_index_file = os.path.join(self.root_path, save_index_file)
index = {"metadata": self.metadata, "weight_map": self.weight_map}
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
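A small sketch (not part of this diff) showing the new constructor argument and write_index_file together; the directory, keys, and sizes are made up for illustration.

import os
from colossalai.checkpoint_io import CheckpointIndexFile

os.makedirs("demo_ckpt", exist_ok=True)
index_file = CheckpointIndexFile("demo_ckpt")  # root_path can now be set at construction
index_file.append_weight_map("linear.weight", "pytorch_model-00001.bin")
index_file.append_weight_map("linear.bias", "pytorch_model-00001.bin")
index_file.append_meta_data("total_size", 4194304)
# Writes demo_ckpt/pytorch_model.bin.index.json with {"metadata": ..., "weight_map": ...}
index_file.write_index_file("pytorch_model.bin.index.json")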
86 changes: 40 additions & 46 deletions colossalai/checkpoint_io/utils.py
@@ -2,7 +2,7 @@
from pathlib import Path
import torch
import torch.nn as nn
from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
from colossalai.tensor.d_tensor.d_tensor import DTensor
import re

@@ -77,55 +77,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
# ======================================
# Helper functions for saving shard file
# ======================================
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:

"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
sharded_state_dicts = []
current_block = {}
current_block_size = 0
total_size = 0

for key, weight in state_dict.items():
ret_block = None
ret_block_size = 0
if type(weight) != DTensor:
weight_size = calculate_tensor_size(weight)

# If this weight is going to tip up over the maximal size, we split.
if current_block_size + weight_size > max_shard_size:
sharded_state_dicts.append(current_block)
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0

current_block[key] = weight
current_block_size += weight_size
total_size += weight_size

if ret_block != None:
yield ret_block, ret_block_size

# Add the last block
sharded_state_dicts.append(current_block)
yield current_block, current_block_size

# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {weights_name: sharded_state_dicts[0]}, None

# Otherwise, let's build the index
weight_map = {}
shards = {}

for idx, shard in enumerate(sharded_state_dicts):
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
shard_file = shard_file.replace(
".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
)
shards[shard_file] = shard
for key in shard.keys():
weight_map[key] = shard_file

# Add the metadata
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
return shards, index
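To make the change from a dict-building helper to a streaming generator concrete, a short sketch (not part of this diff); the toy state dict and shard size are assumptions, and the size unit is whatever calculate_tensor_size uses.

import torch.nn as nn
from colossalai.checkpoint_io.utils import shard_checkpoint

state_dict = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)]).state_dict()

# shard_checkpoint now yields (shard_state_dict, shard_size) pairs lazily: a new
# shard is emitted as soon as adding the next tensor would exceed max_shard_size.
for idx, (shard, size) in enumerate(shard_checkpoint(state_dict, max_shard_size=10)):
    print(f"shard {idx}: {len(shard)} tensors, size {size}")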

def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
"""
@@ -146,7 +126,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
else:
return torch.load(checkpoint_file)

def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants.

@@ -167,29 +147,22 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
if metadata is not None:
state_dict._metadata = metadata

def load(module: nn.Module, state_dict, prefix=""):
def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args)
if load_sub_module:
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")

for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")

load(model, state_dict, "")
load(model, state_dict, "", load_sub_module)
del load

# deal with missing key
if len(missing_keys) > 0:
deleted_keys = []
for key in missing_keys:
if key not in sub_missing_keys:
deleted_keys.append(key)
for key in deleted_keys:
missing_keys.remove(key)
missing_keys = missing_keys.append(sub_missing_keys)

if strict:
if len(unexpected_keys) > 0:
@@ -417,3 +390,24 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
weights_name = ".".join(splits)

return weights_name


def get_base_filenames(variant: str=None, use_safetensors: bool=False):
"""
generate base weight filenames
"""
weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
weights_name = add_variant(weights_name, variant)

save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
save_index_file = add_variant(save_index_file, variant)

return weights_name, save_index_file

def get_shard_filename(weights_name: str, idx: int):
"""
get shard file name
"""
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
return shard_file
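Finally, a quick sketch (not part of this diff) of what the two new filename helpers produce, assuming the usual pytorch_model.bin / pytorch_model.bin.index.json constants.

from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename

weights_name, index_name = get_base_filenames(use_safetensors=False)
# expected: "pytorch_model.bin" and "pytorch_model.bin.index.json" (assumed constants)
print(get_shard_filename(weights_name, 0))   # "pytorch_model-00001.bin"
print(get_shard_filename(weights_name, 11))  # "pytorch_model-00012.bin"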