From 5c1c06c670930632d0b4e11234011ab9d8d39d25 Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 19 Apr 2023 13:26:36 +0800 Subject: [PATCH 01/28] gemini plugin add shard checkpoint save/load --- colossalai/booster/plugin/gemini_plugin.py | 8 ++++ .../checkpoint_io/general_checkpoint_io.py | 42 +++++++--------- colossalai/checkpoint_io/utils.py | 46 ++++++++++++++++-- colossalai/zero/gemini/gemini_ddp.py | 26 ++++++---- .../test_general_checkpoint_io.py | 48 ++++++++++++++++++- .../test_zeroddp_state_dict_shard.py | 6 ++- 6 files changed, 134 insertions(+), 42 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index deda00d8a7b3..e08976b5b589 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,6 +1,7 @@ import random import warnings from typing import Callable, List, Optional, Tuple, Union +from pathlib import Path import numpy as np import torch @@ -62,6 +63,13 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=False) + if self.coordinator.is_master(): + super().save_gemini_shard_ckp(state_dict_shard, checkpoint_path, gather_dtensor, variant, use_safetensors) + + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): + return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors) class GeminiModel(ModelWrapper): diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index bf584f45d045..0c31c3e25ad7 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -6,7 +6,7 @@ import os import json import gc -from typing import Optional +from typing import Optional, Iterator, OrderedDict from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile @@ -18,9 +18,9 @@ shard_checkpoint, load_shard_state_dict, load_state_dict_into_model, - add_variant + build_index, + write_model_files ) -from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME __all__ = ['GeneralCheckpointIO'] @@ -85,27 +85,10 @@ def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dten # shard checkpoint state_dict = model.state_dict() - weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME - weights_name = add_variant(weights_name, variant) - shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) - - # Save the model - for shard_file, shard in shards.items(): - checkpoint_file_path = os.path.join(checkpoint_path, shard_file) - save_state_dict(shard, checkpoint_file_path, use_safetensors) - - # save index file - save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - - save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant)) - with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) - logging.info( - f"The model is bigger 
than the maximum size per checkpoint ({max_shard_size}) and is going to be " - f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) + sharded_state_dicts, total_size = shard_checkpoint(state_dict, max_shard_size=max_shard_size) + # let's build the index + shards, shards_index = build_index(sharded_state_dicts, total_size, use_safetensors, variant) + write_model_files(shards, shards_index, checkpoint_path, use_safetensors) def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): @@ -136,3 +119,14 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( self.__class__.__name__, "\n\t".join(error_msgs))) + def save_gemini_shard_ckp(self, state_dict_shard: Iterator[OrderedDict], checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, use_safetensors: bool = False): + # gather all shards + sharded_state_dicts = [] + total_size = 0 + for shard, s_size in state_dict_shard: + sharded_state_dicts = sharded_state_dicts.append(shard) + total_size = total_size + s_size + + shards, shards_index = build_index(sharded_state_dicts, total_size, use_safetensors, variant) + write_model_files(shards, shards_index, checkpoint_path, use_safetensors) + diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 37d22d08df40..9b779ea7ac46 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -2,9 +2,12 @@ from pathlib import Path import torch import torch.nn as nn -from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple +from typing import List, Mapping, OrderedDict, Optional, Tuple from colossalai.tensor.d_tensor.d_tensor import DTensor import re +import os +import json +import logging SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -77,7 +80,7 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: # ====================================== # Helper functions for saving shard file # ====================================== -def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME): +def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024): """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a @@ -105,14 +108,22 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weigh # Add the last block sharded_state_dicts.append(current_block) + return sharded_state_dicts, total_size + + +def build_index(sharded_state_dicts: List[OrderedDict], total_size: int, use_safetensors: bool, variant: str): # If we only have one shard, we return it + weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + weights_name = add_variant(weights_name, variant) + + save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME + save_index_file = add_variant(save_index_file, variant) + if len(sharded_state_dicts) == 1: return {weights_name: sharded_state_dicts[0]}, None - # Otherwise, let's build the index weight_map = {} shards = {} - for idx, shard in enumerate(sharded_state_dicts): shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") shard_file = shard_file.replace( @@ -125,7 +136,9 @@ def 
shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weigh # Add the metadata metadata = {"total_size": total_size} index = {"metadata": metadata, "weight_map": weight_map} - return shards, index + shards_index = {save_index_file: index} + return shards, shards_index + def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): """ @@ -417,3 +430,26 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str: weights_name = ".".join(splits) return weights_name + + +def write_model_files(shards: dict, shards_index: dict, checkpoint_path: str, use_safetensors: bool = False): + # Save the model + for shard_file, shard in shards.items(): + checkpoint_file_path = os.path.join(checkpoint_path, shard_file) + save_state_dict(shard, checkpoint_file_path, use_safetensors) + + # when it only has one shard, index is None + if shards_index == None: + return + + save_index_file = next(iter(shards_index)) + index = shards_index[save_index_file] + save_index_file = os.path.join(checkpoint_path, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logging.info( + f"The model is going to be split in {len(shards)} checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) \ No newline at end of file diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index e151f1aefb2d..5a76d43a1baa 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -1,7 +1,7 @@ import itertools from collections import OrderedDict from functools import partial -from typing import Dict, Iterator, List, Optional, Union +from typing import Dict, Iterator, List, Optional, Union, Tuple import torch import torch.distributed as dist @@ -583,8 +583,12 @@ def state_dict_shard(self, prefix: str = '', keep_vars: bool = False, max_shard_size: int = 1024, +<<<<<<< Updated upstream only_rank_0: bool = True, dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]: +======= + only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]: +>>>>>>> Stashed changes """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. Both parameters and persistent buffers (e.g. running averages) are included. 
@@ -624,9 +628,9 @@ def state_dict_shard(self, gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) gathered_param = gathered_param_buffer.pop(fp32_param) - block = sharder.append(prefix + name, gathered_param) + block, block_size = sharder.append(prefix + name, gathered_param) if block is not None: - yield block + yield block, block_size del fp16_to_fp32 del gathered_param_buffer @@ -635,19 +639,19 @@ def state_dict_shard(self, for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block = sharder.append(prefix + name, buffer) + block, block_size = sharder.append(prefix + name, buffer) if block is not None: - yield block + yield block, block_size # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX if getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = self.get_extra_state() - block = sharder.append(extra_state_key, extra_state) + block, block_size = sharder.append(extra_state_key, extra_state) if block is not None: - yield block + yield block, block_size - yield sharder.current_block + yield sharder.current_block, sharder.current_block_size class _StateDictSharder: @@ -657,16 +661,18 @@ def __init__(self, max_shard_size: int) -> None: self.current_block = OrderedDict() self.current_block_size = 0 - def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]: + def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: tensor_size = calculate_tensor_size(tensor) ret_block = None + ret_block_size = 0 if self.current_block_size + tensor_size > self.max_shard_size: ret_block = self.current_block + ret_block_size = self.current_block_size self.current_block = OrderedDict() self.current_block_size = 0 self.current_block[name] = tensor self.current_block_size += tensor_size - return ret_block + return ret_block, ret_block_size class GeminiDDP(ZeroDDP): diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index ca5ce10054f7..947f906f72f4 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -11,6 +11,14 @@ from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.testing import clear_cache_before_run, parameterize +import colossalai +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test.registry import non_distributed_component_funcs + # ======== # Note: # 1. 
due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now @@ -83,7 +91,6 @@ def test_sharded_checkpoint(use_safetensors: bool): suffix = ".bin" WEIGHTS_INDEX_NAME = "model.bin.index.json" - # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix) model_ckpt_dir = tempfile.TemporaryDirectory() optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() @@ -105,6 +112,45 @@ def test_sharded_checkpoint(use_safetensors: bool): recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('model_name', ['bert']) +@parameterize('use_safetensors', [True, False]) +def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, *_ = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + new_model = model_builder() + + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager) + model.train() + + #save model + model_ckpt_dir = tempfile.TemporaryDirectory() + ckpt_io = GeneralCheckpointIO() + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors) + + # load model + new_chunk_manager = ChunkManager(config_dict) + new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager) + new_model = ZeroDDP(new_model, new_gemini_manager) + ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + + model_dict, _ = model.state_dict_shard(max_shard_size=10, only_rank_0=False) + accumulated_keys = set() + # ensure number of shards > 1 + for shard, _ in new_model.state_dict_shard(max_shard_size=10, only_rank_0=False): + for key, value in shard.items(): + assert key not in accumulated_keys, f"key `{key}` is duplicated." + accumulated_keys.add(key) + assert key in model_dict, f"{key} not in ZeRO dictionary." + assert torch.equal(value, model_dict[key]), f"{key} not equal." + + # do recursive check for the optimizer state dict # if the value is a dict, compare its values # if the value is a list, comapre all elements one-by-one diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py index 96c26a1de4df..ff17edca8994 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py @@ -30,14 +30,16 @@ def exam_state_dict(placement_policy, model_name: str): zero_dict = model.state_dict(only_rank_0=False) accumulated_keys = set() + total_size = 0 # ensure number of shards > 1 - for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): + for shard, s_size in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): + total_size = total_size + s_size for key, value in shard.items(): assert key not in accumulated_keys, f"key `{key}` is duplicated." accumulated_keys.add(key) assert key in zero_dict, f"{key} not in ZeRO dictionary." assert torch.equal(value, zero_dict[key]), f"{key} not equal." 
- + assert total_size == model_size def run_dist(rank, world_size, port): config = {} From dd2579d63836512f898e58ef3c547469362247a0 Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 19 Apr 2023 13:29:15 +0800 Subject: [PATCH 02/28] gemini plugin add shard checkpoint save/load --- colossalai/zero/gemini/gemini_ddp.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 5a76d43a1baa..1acef4f64ddf 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -583,12 +583,8 @@ def state_dict_shard(self, prefix: str = '', keep_vars: bool = False, max_shard_size: int = 1024, -<<<<<<< Updated upstream only_rank_0: bool = True, - dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]: -======= - only_rank_0: bool = True) -> Iterator[Tuple[OrderedDict, int]]: ->>>>>>> Stashed changes + dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. Both parameters and persistent buffers (e.g. running averages) are included. From 351d7ebefa5dee95b9f97085717abe66eaf69b8a Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 19 Apr 2023 15:02:52 +0800 Subject: [PATCH 03/28] gemini plugin add shard checkpoint save/load --- .../checkpoint_io/general_checkpoint_io.py | 18 ++++++++++++++++-- colossalai/checkpoint_io/index_file.py | 1 + 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 0c31c3e25ad7..2036f028fd25 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -18,9 +18,23 @@ shard_checkpoint, load_shard_state_dict, load_state_dict_into_model, - build_index, - write_model_files + add_variant ) +from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME + +# from checkpoint_io_base import CheckpointIO +# from index_file import CheckpointIndexFile +# from utils import ( +# has_index_file, +# load_state_dict, +# save_state_dict, +# is_safetensors_available, +# shard_checkpoint, +# load_shard_state_dict, +# load_state_dict_into_model, +# build_index, +# write_model_files +# ) __all__ = ['GeneralCheckpointIO'] diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 89224787a91b..828000e709e4 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -3,6 +3,7 @@ from typing import Any, List, Union from .utils import is_dtensor_checkpoint +# from utils import is_dtensor_checkpoint __all__ = ['CheckpointIndexFile'] From a43fae85a37d1fd6353ed3aa0c8c1764fa0eb76e Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 19 Apr 2023 16:36:17 +0800 Subject: [PATCH 04/28] gemini plugin add shard checkpoint save/load --- .../checkpoint_io/general_checkpoint_io.py | 37 ++++++------------- colossalai/checkpoint_io/index_file.py | 1 - 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 2036f028fd25..ed7af16e7b83 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -18,24 +18,9 @@ shard_checkpoint, load_shard_state_dict, load_state_dict_into_model, - 
add_variant + build_index, + write_model_files ) -from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME - -# from checkpoint_io_base import CheckpointIO -# from index_file import CheckpointIndexFile -# from utils import ( -# has_index_file, -# load_state_dict, -# save_state_dict, -# is_safetensors_available, -# shard_checkpoint, -# load_shard_state_dict, -# load_state_dict_into_model, -# build_index, -# write_model_files -# ) - __all__ = ['GeneralCheckpointIO'] @@ -133,14 +118,14 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( self.__class__.__name__, "\n\t".join(error_msgs))) - def save_gemini_shard_ckp(self, state_dict_shard: Iterator[OrderedDict], checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, use_safetensors: bool = False): - # gather all shards - sharded_state_dicts = [] - total_size = 0 - for shard, s_size in state_dict_shard: - sharded_state_dicts = sharded_state_dicts.append(shard) - total_size = total_size + s_size + # def save_gemini_shard_ckp(self, state_dict_shard: Iterator[OrderedDict], checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, use_safetensors: bool = False): + # # gather all shards + # sharded_state_dicts = [] + # total_size = 0 + # for shard, s_size in state_dict_shard: + # sharded_state_dicts = sharded_state_dicts.append(shard) + # total_size = total_size + s_size - shards, shards_index = build_index(sharded_state_dicts, total_size, use_safetensors, variant) - write_model_files(shards, shards_index, checkpoint_path, use_safetensors) + # shards, shards_index = build_index(sharded_state_dicts, total_size, use_safetensors, variant) + # write_model_files(shards, shards_index, checkpoint_path, use_safetensors) diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 828000e709e4..89224787a91b 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -3,7 +3,6 @@ from typing import Any, List, Union from .utils import is_dtensor_checkpoint -# from utils import is_dtensor_checkpoint __all__ = ['CheckpointIndexFile'] From d0ab0a02e7e0846d9890f2221ac92e75b41b9112 Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 19 Apr 2023 17:20:44 +0800 Subject: [PATCH 05/28] gemini plugin add shard checkpoint save/load --- .../checkpoint_io/general_checkpoint_io.py | 28 ++++++------ colossalai/checkpoint_io/utils.py | 43 ++++++++++++------- 2 files changed, 43 insertions(+), 28 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index ed7af16e7b83..c24143f3e0b2 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -6,7 +6,8 @@ import os import json import gc -from typing import Optional, Iterator, OrderedDict +from typing import Optional, Iterator, OrderedDict, Tuple +import itertools from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile @@ -84,9 +85,15 @@ def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dten # shard checkpoint state_dict = model.state_dict() - sharded_state_dicts, total_size = shard_checkpoint(state_dict, max_shard_size=max_shard_size) + state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size) + + + # copy a duplicated iterator to get the total 
number of shards + state_dict_shard_tee = itertools.tee(state_dict_shard, 2) + shards_total_num = sum(1 for _ in state_dict_shard_tee[0]) + # let's build the index - shards, shards_index = build_index(sharded_state_dicts, total_size, use_safetensors, variant) + shards, shards_index = build_index(state_dict_shard_tee[1], shards_total_num, use_safetensors, variant) write_model_files(shards, shards_index, checkpoint_path, use_safetensors) @@ -118,14 +125,11 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( self.__class__.__name__, "\n\t".join(error_msgs))) - # def save_gemini_shard_ckp(self, state_dict_shard: Iterator[OrderedDict], checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, use_safetensors: bool = False): - # # gather all shards - # sharded_state_dicts = [] - # total_size = 0 - # for shard, s_size in state_dict_shard: - # sharded_state_dicts = sharded_state_dicts.append(shard) - # total_size = total_size + s_size + def save_gemini_shard_ckp(self, state_dict_shard: Iterator[Tuple[OrderedDict, int]], checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, use_safetensors: bool = False): - # shards, shards_index = build_index(sharded_state_dicts, total_size, use_safetensors, variant) - # write_model_files(shards, shards_index, checkpoint_path, use_safetensors) + # copy a duplicated iterator to get the total number of shards + state_dict_shard_tee = itertools.tee(state_dict_shard, 2) + shards_total_num = sum(1 for _ in state_dict_shard_tee[0]) + shards, shards_index = build_index(state_dict_shard[1], shards_total_num, use_safetensors, variant) + write_model_files(shards, shards_index, checkpoint_path, use_safetensors) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 9b779ea7ac46..58a027445588 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -2,7 +2,7 @@ from pathlib import Path import torch import torch.nn as nn -from typing import List, Mapping, OrderedDict, Optional, Tuple +from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator from colossalai.tensor.d_tensor.d_tensor import DTensor import re import os @@ -80,38 +80,44 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: # ====================================== # Helper functions for saving shard file # ====================================== -def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024): +def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. """ - sharded_state_dicts = [] + # sharded_state_dicts = [] current_block = {} current_block_size = 0 - total_size = 0 + # total_size = 0 for key, weight in state_dict.items(): + ret_block = None + ret_block_size = 0 if type(weight) != DTensor: weight_size = calculate_tensor_size(weight) # If this weight is going to tip up over the maximal size, we split. 
if current_block_size + weight_size > max_shard_size: - sharded_state_dicts.append(current_block) + ret_block = current_block + ret_block_size = current_block_size + # sharded_state_dicts.append(current_block) current_block = {} current_block_size = 0 current_block[key] = weight current_block_size += weight_size - total_size += weight_size + # total_size += weight_size + if ret_block != None: + yield ret_block, ret_block_size # Add the last block - sharded_state_dicts.append(current_block) + # sharded_state_dicts.append(current_block) - return sharded_state_dicts, total_size + yield current_block, current_block_size -def build_index(sharded_state_dicts: List[OrderedDict], total_size: int, use_safetensors: bool, variant: str): +def build_index(state_dict_shard: Iterator[Tuple[OrderedDict, int]], shards_total_num: int, use_safetensors: bool, variant: str): # If we only have one shard, we return it weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME weights_name = add_variant(weights_name, variant) @@ -119,19 +125,24 @@ def build_index(sharded_state_dicts: List[OrderedDict], total_size: int, use_saf save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME save_index_file = add_variant(save_index_file, variant) - if len(sharded_state_dicts) == 1: - return {weights_name: sharded_state_dicts[0]}, None + if shards_total_num == 1: + # return {weights_name: sharded_state_dicts[0]}, None + # print("bbbb", next(state_dict_shard)) + return {weights_name: next(state_dict_shard)[0]}, None weight_map = {} shards = {} - for idx, shard in enumerate(sharded_state_dicts): - shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") + total_size = 0 + # shard_pair is like (shard, shard_size) + for idx, shard_pair in enumerate(state_dict_shard): + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{shards_total_num:05d}.bin") shard_file = shard_file.replace( - ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" + ".safetensors", f"-{idx + 1:05d}-of-{shards_total_num:05d}.safetensors" ) - shards[shard_file] = shard - for key in shard.keys(): + shards[shard_file] = shard_pair[0] + for key in shard_pair[0].keys(): weight_map[key] = shard_file + total_size = total_size + shard_pair[1] # Add the metadata metadata = {"total_size": total_size} From a636b469617a3287867f0dc5297a469d49f05f54 Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 19 Apr 2023 17:24:53 +0800 Subject: [PATCH 06/28] gemini plugin add shard checkpoint save/load --- colossalai/checkpoint_io/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 58a027445588..1e3fd353f626 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -133,7 +133,7 @@ def build_index(state_dict_shard: Iterator[Tuple[OrderedDict, int]], shards_tota weight_map = {} shards = {} total_size = 0 - # shard_pair is like (shard, shard_size) + # shard_pair like (shard, shard_size) for idx, shard_pair in enumerate(state_dict_shard): shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{shards_total_num:05d}.bin") shard_file = shard_file.replace( From 4f9f603693cf9568f957ed6aef4a3baf4c7b35f6 Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 19 Apr 2023 17:26:29 +0800 Subject: [PATCH 07/28] gemini plugin add shard checkpoint save/load --- colossalai/checkpoint_io/utils.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) 
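
A note on the approach patches 05-07 converge on: shard_checkpoint now lazily yields (shard, shard_size) pairs instead of returning a list, and the caller duplicates the iterator with itertools.tee so one copy can be exhausted just to count the shards before the other copy is written out. A minimal self-contained sketch of that counting pattern (the fake_shards generator below is an illustrative stand-in for shard_checkpoint, not part of the library):

    import itertools
    from collections import OrderedDict

    def fake_shards():
        # stand-in for shard_checkpoint(...): lazily yields (shard, shard_size) pairs
        yield OrderedDict(weight=[1.0, 2.0]), 16
        yield OrderedDict(bias=[0.5]), 8

    counting_copy, writing_copy = itertools.tee(fake_shards(), 2)
    shards_total_num = sum(1 for _ in counting_copy)    # consumes only the counting copy
    for idx, (shard, shard_size) in enumerate(writing_copy):
        print(f"shard {idx + 1:05d}-of-{shards_total_num:05d}: {shard_size} bytes")

The catch is that itertools.tee buffers everything the leading copy has consumed, so counting first keeps every shard in memory at once, which is presumably why patch 15 below switches to a single pass over the iterator.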
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 1e3fd353f626..4d783d40bb63 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -86,10 +86,8 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. """ - # sharded_state_dicts = [] current_block = {} current_block_size = 0 - # total_size = 0 for key, weight in state_dict.items(): ret_block = None @@ -101,18 +99,13 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It if current_block_size + weight_size > max_shard_size: ret_block = current_block ret_block_size = current_block_size - # sharded_state_dicts.append(current_block) current_block = {} current_block_size = 0 - current_block[key] = weight current_block_size += weight_size - # total_size += weight_size - + if ret_block != None: yield ret_block, ret_block_size - # Add the last block - # sharded_state_dicts.append(current_block) yield current_block, current_block_size @@ -126,8 +119,6 @@ def build_index(state_dict_shard: Iterator[Tuple[OrderedDict, int]], shards_tota save_index_file = add_variant(save_index_file, variant) if shards_total_num == 1: - # return {weights_name: sharded_state_dicts[0]}, None - # print("bbbb", next(state_dict_shard)) return {weights_name: next(state_dict_shard)[0]}, None weight_map = {} From 9d677506118205fc338efdfa89ed1db52495770c Mon Sep 17 00:00:00 2001 From: luchen Date: Wed, 19 Apr 2023 17:27:34 +0800 Subject: [PATCH 08/28] gemini plugin add shard checkpoint save/load --- tests/test_checkpoint_io/test_general_checkpoint_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 947f906f72f4..2a9abdc633b8 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -142,7 +142,7 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): model_dict, _ = model.state_dict_shard(max_shard_size=10, only_rank_0=False) accumulated_keys = set() - # ensure number of shards > 1 + for shard, _ in new_model.state_dict_shard(max_shard_size=10, only_rank_0=False): for key, value in shard.items(): assert key not in accumulated_keys, f"key `{key}` is duplicated." 
From 53bc2482289812b2abcd88f75c625c20344de018 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 19 Apr 2023 22:47:33 +0800 Subject: [PATCH 09/28] gemini plugin add shard checkpoint save/load --- .../checkpoint_io/checkpoint_io_base.py | 2 +- .../checkpoint_io/general_checkpoint_io.py | 3 +- pytest.ini | 6 ++- .../test_general_checkpoint_io.py | 51 +++++++++++++------ .../test_zeroddp_state_dict_shard.py | 1 - 5 files changed, 44 insertions(+), 19 deletions(-) diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 3f8b0b0a6b47..1b1d1dc88b2d 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -86,7 +86,7 @@ def load_model(self, # the existence of index file means it is a sharded checkpoint ckpt_path = Path(checkpoint) index_file_exists, index_file_path = has_index_file(checkpoint) - + # assert index_file_exists == True # return the origin model instead of the unwrapped model origin_model = model diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index c24143f3e0b2..a434a6606aba 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -130,6 +130,7 @@ def save_gemini_shard_ckp(self, state_dict_shard: Iterator[Tuple[OrderedDict, in # copy a duplicated iterator to get the total number of shards state_dict_shard_tee = itertools.tee(state_dict_shard, 2) shards_total_num = sum(1 for _ in state_dict_shard_tee[0]) - shards, shards_index = build_index(state_dict_shard[1], shards_total_num, use_safetensors, variant) + logging.warning("shards_total_num {shards_total_num}") + shards, shards_index = build_index(state_dict_shard_tee[1], shards_total_num, use_safetensors, variant) write_model_files(shards, shards_index, checkpoint_path, use_safetensors) diff --git a/pytest.ini b/pytest.ini index ac31ace4bfae..86997db62667 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,10 @@ [pytest] +log_cli = true +log_cli_level=DEBUG +log_cli_format= %(asctime)s %(levelname)s %(message)s +log_cli_date_format = %Y-%m-%d %H:%M:%S markers = cpu: tests which can run on CPU gpu: tests which requires a single GPU dist: tests which are run in a multi-GPU or multi-machine environment - experiment: tests for experimental features \ No newline at end of file + experiment: tests for experimental features diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 2a9abdc633b8..351fd430690d 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -7,8 +7,12 @@ from pathlib import Path import os import subprocess +import logging +import pathlib +import shutil from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO from colossalai.testing import clear_cache_before_run, parameterize import colossalai @@ -116,6 +120,7 @@ def test_sharded_checkpoint(use_safetensors: bool): @parameterize('model_name', ['bert']) @parameterize('use_safetensors', [True, False]) def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): + logging.info("aaaa") get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, *_ = get_components_func() @@ -129,26 +134,34 @@ def exam_state_dict(placement_policy, 
model_name: str, use_safetensors: bool): model = ZeroDDP(model, gemini_manager) model.train() - #save model model_ckpt_dir = tempfile.TemporaryDirectory() - ckpt_io = GeneralCheckpointIO() - ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors) + ckpt_io = GeminiCheckpointIO() + + model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "epoch", (model_size / 3), use_safetensors=use_safetensors) # load model - new_chunk_manager = ChunkManager(config_dict) - new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager) - new_model = ZeroDDP(new_model, new_gemini_manager) - ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + if ckpt_io.coordinator.is_master(): + ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True) + model.to('cpu') + model_dict = model.state_dict(only_rank_0=True) + new_model.to('cpu') + new_model_dict = new_model.state_dict() + # recursive_check(model_dict, new_model_dict) + model_ckpt_dir.cleanup() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() - model_dict, _ = model.state_dict_shard(max_shard_size=10, only_rank_0=False) - accumulated_keys = set() - for shard, _ in new_model.state_dict_shard(max_shard_size=10, only_rank_0=False): - for key, value in shard.items(): - assert key not in accumulated_keys, f"key `{key}` is duplicated." - accumulated_keys.add(key) - assert key in model_dict, f"{key} not in ZeRO dictionary." - assert torch.equal(value, model_dict[key]), f"{key} not equal." +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 4]) +@rerun_if_address_is_in_use() +def test_gemini_ckpIO(world_size): + spawn(run_dist, world_size) # do recursive check for the optimizer state dict @@ -163,10 +176,18 @@ def recursive_check(d1, d2): elif isinstance(v, list): for i in range(len(v)): if isinstance(v[i], torch.Tensor): + v[i].to("cpu") + d2[k][i].to("cpu") + # assert v[i].device == "cpu" + # assert d2[k][i].device == "cpu" assert torch.equal(v[i], d2[k][i]) else: assert v[i] == d2[k][i] elif isinstance(v, torch.Tensor): + v.to("cpu") + d2[k].to("cpu") + # assert v.device == "cpu" + # assert d2[k].device == "cpu" assert torch.equal(v, d2[k]) else: assert v == d2[k] diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py index ff17edca8994..eb09eb2af340 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py @@ -39,7 +39,6 @@ def exam_state_dict(placement_policy, model_name: str): accumulated_keys.add(key) assert key in zero_dict, f"{key} not in ZeRO dictionary." assert torch.equal(value, zero_dict[key]), f"{key} not equal." 
- assert total_size == model_size def run_dist(rank, world_size, port): config = {} From 777ac8989d4f00e8762b69747b1d47b469f5019b Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 19 Apr 2023 22:58:27 +0800 Subject: [PATCH 10/28] gemini plugin add shard checkpoint save/load --- colossalai/checkpoint_io/checkpoint_io_base.py | 1 - colossalai/checkpoint_io/general_checkpoint_io.py | 1 - pytest.ini | 4 ---- tests/test_checkpoint_io/test_general_checkpoint_io.py | 5 ----- 4 files changed, 11 deletions(-) diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 1b1d1dc88b2d..e5bfcc68ac71 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -86,7 +86,6 @@ def load_model(self, # the existence of index file means it is a sharded checkpoint ckpt_path = Path(checkpoint) index_file_exists, index_file_path = has_index_file(checkpoint) - # assert index_file_exists == True # return the origin model instead of the unwrapped model origin_model = model diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index a434a6606aba..d6974a888e87 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -130,7 +130,6 @@ def save_gemini_shard_ckp(self, state_dict_shard: Iterator[Tuple[OrderedDict, in # copy a duplicated iterator to get the total number of shards state_dict_shard_tee = itertools.tee(state_dict_shard, 2) shards_total_num = sum(1 for _ in state_dict_shard_tee[0]) - logging.warning("shards_total_num {shards_total_num}") shards, shards_index = build_index(state_dict_shard_tee[1], shards_total_num, use_safetensors, variant) write_model_files(shards, shards_index, checkpoint_path, use_safetensors) diff --git a/pytest.ini b/pytest.ini index 86997db62667..01e5cd217c5d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,8 +1,4 @@ [pytest] -log_cli = true -log_cli_level=DEBUG -log_cli_format= %(asctime)s %(levelname)s %(message)s -log_cli_date_format = %Y-%m-%d %H:%M:%S markers = cpu: tests which can run on CPU gpu: tests which requires a single GPU diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 351fd430690d..e63a679d5542 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -120,7 +120,6 @@ def test_sharded_checkpoint(use_safetensors: bool): @parameterize('model_name', ['bert']) @parameterize('use_safetensors', [True, False]) def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): - logging.info("aaaa") get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, *_ = get_components_func() @@ -178,16 +177,12 @@ def recursive_check(d1, d2): if isinstance(v[i], torch.Tensor): v[i].to("cpu") d2[k][i].to("cpu") - # assert v[i].device == "cpu" - # assert d2[k][i].device == "cpu" assert torch.equal(v[i], d2[k][i]) else: assert v[i] == d2[k][i] elif isinstance(v, torch.Tensor): v.to("cpu") d2[k].to("cpu") - # assert v.device == "cpu" - # assert d2[k].device == "cpu" assert torch.equal(v, d2[k]) else: assert v == d2[k] From 327c9a3a05f060e999a41da443b7d1ccda35233d Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 20 Apr 2023 10:55:49 +0800 Subject: [PATCH 11/28] gemini plugin add shard checkpoint save/load --- 
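
One detail behind the test churn in this patch and the previous one: torch.Tensor.to is not an in-place operation, so the bare v.to("cpu") and d2[k].to("cpu") calls that patch 09 added to recursive_check discard their results and move nothing; the hunk below deletes them again. A short illustration:

    import torch

    t = torch.ones(2)
    t.to("cpu")        # returns a tensor; the result is discarded and t is unchanged
    t = t.to("cpu")    # rebinding is how the moved/converted tensor is actually kept
    assert t.device.type == "cpu"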
.../test_general_checkpoint_io.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index e63a679d5542..6293edb9e817 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -8,7 +8,6 @@ import os import subprocess import logging -import pathlib import shutil from colossalai.checkpoint_io import GeneralCheckpointIO @@ -117,7 +116,7 @@ def test_sharded_checkpoint(use_safetensors: bool): @parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('model_name', ['bert']) +@parameterize('model_name', ['gpt2', 'bert']) @parameterize('use_safetensors', [True, False]) def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -131,6 +130,12 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): chunk_manager = ChunkManager(config_dict) gemini_manager = GeminiManager(placement_policy, chunk_manager) model = ZeroDDP(model, gemini_manager) + + new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100) + new_chunk_manager = ChunkManager(new_config_dict) + new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager) + new_model = ZeroDDP(new_model, new_gemini_manager) + model.train() model_ckpt_dir = tempfile.TemporaryDirectory() @@ -145,8 +150,8 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): model.to('cpu') model_dict = model.state_dict(only_rank_0=True) new_model.to('cpu') - new_model_dict = new_model.state_dict() - # recursive_check(model_dict, new_model_dict) + new_model_dict = new_model.state_dict(only_rank_0=True) + recursive_check(model_dict, new_model_dict) model_ckpt_dir.cleanup() @@ -175,14 +180,10 @@ def recursive_check(d1, d2): elif isinstance(v, list): for i in range(len(v)): if isinstance(v[i], torch.Tensor): - v[i].to("cpu") - d2[k][i].to("cpu") assert torch.equal(v[i], d2[k][i]) else: assert v[i] == d2[k][i] elif isinstance(v, torch.Tensor): - v.to("cpu") - d2[k].to("cpu") assert torch.equal(v, d2[k]) else: assert v == d2[k] From a75cc8650f6ed76d5edc688cf1d340586016ba1e Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 20 Apr 2023 11:03:55 +0800 Subject: [PATCH 12/28] gemini plugin add shard checkpoint save/load --- tests/test_checkpoint_io/test_general_checkpoint_io.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 6293edb9e817..aac6551b1710 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -1,14 +1,8 @@ import tempfile import pytest import torch -import logging from torch.optim import Adam from torchvision.models import resnet18 -from pathlib import Path -import os -import subprocess -import logging -import shutil from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO From 83c5740825a816beeeb3912e29d0dfb9b1e54991 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 20 Apr 2023 13:42:39 +0800 Subject: [PATCH 13/28] gemini plugin add shard checkpoint save/load --- colossalai/booster/plugin/gemini_plugin.py | 4 ++-- 
colossalai/checkpoint_io/general_checkpoint_io.py | 12 ++---------- .../test_gemini/test_zeroddp_state_dict_shard.py | 4 +--- 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index e08976b5b589..26e1114d0c5c 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -64,9 +64,9 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): super().save_lr_scheduler(lr_scheduler, checkpoint) def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): - state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=False) + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True) if self.coordinator.is_master(): - super().save_gemini_shard_ckp(state_dict_shard, checkpoint_path, gather_dtensor, variant, use_safetensors) + super().save_shards(state_dict_shard, checkpoint_path, variant, use_safetensors) def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index d6974a888e87..dd8089c32103 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -87,14 +87,7 @@ def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dten state_dict = model.state_dict() state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size) - - # copy a duplicated iterator to get the total number of shards - state_dict_shard_tee = itertools.tee(state_dict_shard, 2) - shards_total_num = sum(1 for _ in state_dict_shard_tee[0]) - - # let's build the index - shards, shards_index = build_index(state_dict_shard_tee[1], shards_total_num, use_safetensors, variant) - write_model_files(shards, shards_index, checkpoint_path, use_safetensors) + self.save_shards(state_dict_shard, checkpoint_path, variant, use_safetensors) def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): @@ -125,8 +118,7 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( self.__class__.__name__, "\n\t".join(error_msgs))) - def save_gemini_shard_ckp(self, state_dict_shard: Iterator[Tuple[OrderedDict, int]], checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, use_safetensors: bool = False): - + def save_shards(self, state_dict_shard: Iterator[Tuple[OrderedDict, int]], checkpoint_path: str, variant: Optional[str] = None, use_safetensors: bool = False): # copy a duplicated iterator to get the total number of shards state_dict_shard_tee = itertools.tee(state_dict_shard, 2) shards_total_num = sum(1 for _ in state_dict_shard_tee[0]) diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py index eb09eb2af340..ad7d3a5a4859 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py @@ 
-30,10 +30,8 @@ def exam_state_dict(placement_policy, model_name: str): zero_dict = model.state_dict(only_rank_0=False) accumulated_keys = set() - total_size = 0 # ensure number of shards > 1 - for shard, s_size in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): - total_size = total_size + s_size + for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): for key, value in shard.items(): assert key not in accumulated_keys, f"key `{key}` is duplicated." accumulated_keys.add(key) From f90afe40a099fef8404598990fc8933f8ce81af7 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 20 Apr 2023 18:58:35 +0800 Subject: [PATCH 14/28] gemini plugin add shard checkpoint save/load --- .../checkpoint_io/general_checkpoint_io.py | 4 +--- colossalai/checkpoint_io/utils.py | 20 ++++++++++--------- .../test_general_checkpoint_io.py | 17 ++++++++++++++++ 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index dd8089c32103..87ccf65c6e77 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -120,8 +120,6 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri def save_shards(self, state_dict_shard: Iterator[Tuple[OrderedDict, int]], checkpoint_path: str, variant: Optional[str] = None, use_safetensors: bool = False): # copy a duplicated iterator to get the total number of shards - state_dict_shard_tee = itertools.tee(state_dict_shard, 2) - shards_total_num = sum(1 for _ in state_dict_shard_tee[0]) - shards, shards_index = build_index(state_dict_shard_tee[1], shards_total_num, use_safetensors, variant) + shards, shards_index = build_index(state_dict_shard, use_safetensors, variant) write_model_files(shards, shards_index, checkpoint_path, use_safetensors) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 4d783d40bb63..f8880f8ab0a9 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -110,7 +110,7 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It yield current_block, current_block_size -def build_index(state_dict_shard: Iterator[Tuple[OrderedDict, int]], shards_total_num: int, use_safetensors: bool, variant: str): +def build_index(state_dict_shard: Iterator[Tuple[OrderedDict, int]], use_safetensors: bool, variant: str): # If we only have one shard, we return it weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME weights_name = add_variant(weights_name, variant) @@ -118,23 +118,25 @@ def build_index(state_dict_shard: Iterator[Tuple[OrderedDict, int]], shards_tota save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME save_index_file = add_variant(save_index_file, variant) - if shards_total_num == 1: - return {weights_name: next(state_dict_shard)[0]}, None - weight_map = {} shards = {} total_size = 0 + shards_total_num = 0 # shard_pair like (shard, shard_size) + single_shard = None for idx, shard_pair in enumerate(state_dict_shard): - shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{shards_total_num:05d}.bin") - shard_file = shard_file.replace( - ".safetensors", f"-{idx + 1:05d}-of-{shards_total_num:05d}.safetensors" - ) + if idx == 0: + single_shard = shard_pair[0] + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") + shard_file = 
shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") shards[shard_file] = shard_pair[0] for key in shard_pair[0].keys(): weight_map[key] = shard_file + shards_total_num = shards_total_num + 1 total_size = total_size + shard_pair[1] - + # if only one shard, then we don't build index + if shards_total_num == 1: + return {weights_name: single_shard}, None # Add the metadata metadata = {"total_size": total_size} index = {"metadata": metadata, "weight_map": weight_map} diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index aac6551b1710..376790a59a11 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -109,6 +109,23 @@ def test_sharded_checkpoint(use_safetensors: bool): recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) +@pytest.mark.parametrize('use_safetensors', [True, False]) +def test_hf_load_colossalai_checkpoint(use_safetensors: bool): + from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig + + ckpt_io = GeneralCheckpointIO() + bert_model = BertModel.from_pretrained('bert-base-chinese') + model_ckpt_dir = tempfile.TemporaryDirectory() + bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name) + ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + + new_bert_config = bert_model.config + new_bert_model = BertModel(config=new_bert_config) + new_bert_model = BertModel.from_pretrained(model_ckpt_dir.name) + + recursive_check(bert_model.state_dict(), new_bert_model.state_dict()) + + @parameterize('placement_policy', ['cuda', 'cpu']) @parameterize('model_name', ['gpt2', 'bert']) @parameterize('use_safetensors', [True, False]) From dd7d03ffb60e8d715f66235af59e6e74e3309283 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Mon, 24 Apr 2023 16:48:20 +0800 Subject: [PATCH 15/28] gemini plugin support shard checkpoint --- colossalai/booster/plugin/gemini_plugin.py | 57 ++++++++++++++- .../checkpoint_io/general_checkpoint_io.py | 49 +++++++++++-- colossalai/checkpoint_io/index_file.py | 17 ++++- colossalai/checkpoint_io/utils.py | 72 +++++-------------- .../test_general_checkpoint_io.py | 8 +++ 5 files changed, 137 insertions(+), 66 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 26e1114d0c5c..4a9937a324dc 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -2,6 +2,9 @@ import warnings from typing import Callable, List, Optional, Tuple, Union from pathlib import Path +import os +import json +import logging import numpy as np import torch @@ -21,6 +24,13 @@ from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper from colossalai.zero.gemini.memory_tracer import MemStats +from colossalai.checkpoint_io.utils import ( + get_base_filenames, + get_shard_filename + ) + +from colossalai.checkpoint_io import CheckpointIndexFile + from .plugin_base import Plugin __all__ = ['GeminiPlugin'] @@ -64,9 +74,50 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): super().save_lr_scheduler(lr_scheduler, checkpoint) def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False): - state_dict_shard = 
model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True) - if self.coordinator.is_master(): - super().save_shards(state_dict_shard, checkpoint_path, variant, use_safetensors) + """ + Save sharded model + """ + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=False) + weights_name, save_index_file = get_base_filenames(variant, use_safetensors) + total_size = 0 + single_shard = None + single_shard_file = None + index_file = CheckpointIndexFile(checkpoint_path) + for idx, shard_pair in enumerate(state_dict_shard): + if not self.coordinator.is_master(): + continue + shard = shard_pair[0] + shard_file = get_shard_filename(weights_name, idx) + total_size = total_size + shard_pair[1] + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) + if idx == 0: + single_shard = shard + single_shard_file = get_shard_filename(weights_name, idx) + continue + if idx == 1: + checkpoint_file_path = os.path.join(checkpoint_path, single_shard_file) + save_state_dict(single_shard, checkpoint_file_path, use_safetensors) + single_shard = None + single_shard_file = None + + total_size = total_size + shard_pair[1] + checkpoint_file_path = os.path.join(checkpoint_path, shard_file) + save_state_dict(shard, checkpoint_file_path, use_safetensors) + + if single_shard is not None: + checkpoint_file_path = os.path.join(checkpoint_path, weights_name) + save_state_dict(single_shard, checkpoint_file_path, use_safetensors) + return + + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info( + f"The model is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 87ccf65c6e77..dd008fbeb90e 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -19,8 +19,8 @@ shard_checkpoint, load_shard_state_dict, load_state_dict_into_model, - build_index, - write_model_files + get_shard_filename, + get_base_filenames ) __all__ = ['GeneralCheckpointIO'] @@ -87,7 +87,45 @@ def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dten state_dict = model.state_dict() state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size) - self.save_shards(state_dict_shard, checkpoint_path, variant, use_safetensors) + weights_name, save_index_file = get_base_filenames(variant, use_safetensors) + total_size = 0 + single_shard = None + single_shard_file = None + index_file = CheckpointIndexFile(checkpoint_path) + for idx, shard_pair in enumerate(state_dict_shard): + shard = shard_pair[0] + shard_file = get_shard_filename(weights_name, idx) + total_size = total_size + shard_pair[1] + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) + if idx == 0: + single_shard = shard + single_shard_file = get_shard_filename(weights_name, idx) + continue + if idx == 1: + print(single_shard) + print("single_shard_file", single_shard_file) + checkpoint_file_path = os.path.join(checkpoint_path, single_shard_file) + save_state_dict(single_shard, checkpoint_file_path, use_safetensors) + single_shard 
+                single_shard = None
+                single_shard_file = None
+
+            total_size = total_size + shard_pair[1]
+            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
+            save_state_dict(shard, checkpoint_file_path, use_safetensors)
+
+        if single_shard is not None:
+            checkpoint_file_path = os.path.join(checkpoint_path, weights_name)
+            save_state_dict(single_shard, checkpoint_file_path, use_safetensors)
+            return
+
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)
+        logging.info(
+            f"The model is going to be split into checkpoint shards. "
+            f"You can find where each parameter has been saved in the "
+            f"index located at {save_index_file}."
+        )
 
 
     def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
@@ -118,8 +156,5 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri
         raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
             self.__class__.__name__, "\n\t".join(error_msgs)))
 
-    def save_shards(self, state_dict_shard: Iterator[Tuple[OrderedDict, int]], checkpoint_path: str, variant: Optional[str] = None, use_safetensors: bool = False):
-        # copy a duplicated iterator to get the total number of shards
-        shards, shards_index = build_index(state_dict_shard, use_safetensors, variant)
-        write_model_files(shards, shards_index, checkpoint_path, use_safetensors)
+    
diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py
index 89224787a91b..b2fb03ba3016 100644
--- a/colossalai/checkpoint_io/index_file.py
+++ b/colossalai/checkpoint_io/index_file.py
@@ -1,6 +1,9 @@
 import json
 from pathlib import Path
 from typing import Any, List, Union
+import os
+import json
+import logging
 
 from .utils import is_dtensor_checkpoint
 
@@ -18,8 +21,8 @@ class CheckpointIndexFile:
     >>> index.export('new_index.json')
     """
 
-    def __init__(self) -> None:
-        self.root_path = None
+    def __init__(self, root_path=None) -> None:
+        self.root_path = root_path
         self.metadata: dict = dict()
         self.weight_map: dict = dict()
 
@@ -154,3 +157,13 @@ def get_all_param_names(self):
         Get all the weight keys.
         """
         return list(self.weight_map.keys())
+
+    def write_index_file(self, save_index_file):
+        """
+        Write index file.
+ """ + save_index_file = os.path.join(self.root_path, save_index_file) + index = {"metadata": self.metadata, "weight_map": self.weight_map} + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index f8880f8ab0a9..0a08587fc74f 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -108,40 +108,6 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It yield ret_block, ret_block_size yield current_block, current_block_size - - -def build_index(state_dict_shard: Iterator[Tuple[OrderedDict, int]], use_safetensors: bool, variant: str): - # If we only have one shard, we return it - weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME - weights_name = add_variant(weights_name, variant) - - save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - save_index_file = add_variant(save_index_file, variant) - - weight_map = {} - shards = {} - total_size = 0 - shards_total_num = 0 - # shard_pair like (shard, shard_size) - single_shard = None - for idx, shard_pair in enumerate(state_dict_shard): - if idx == 0: - single_shard = shard_pair[0] - shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") - shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") - shards[shard_file] = shard_pair[0] - for key in shard_pair[0].keys(): - weight_map[key] = shard_file - shards_total_num = shards_total_num + 1 - total_size = total_size + shard_pair[1] - # if only one shard, then we don't build index - if shards_total_num == 1: - return {weights_name: single_shard}, None - # Add the metadata - metadata = {"total_size": total_size} - index = {"metadata": metadata, "weight_map": weight_map} - shards_index = {save_index_file: index} - return shards, shards_index def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): @@ -436,24 +402,22 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name -def write_model_files(shards: dict, shards_index: dict, checkpoint_path: str, use_safetensors: bool = False): - # Save the model - for shard_file, shard in shards.items(): - checkpoint_file_path = os.path.join(checkpoint_path, shard_file) - save_state_dict(shard, checkpoint_file_path, use_safetensors) +def get_base_filenames(variant: str=None, use_safetensors: bool=False): + """ + generate base weight filenames + """ + weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + weights_name = add_variant(weights_name, variant) - # when it only has one shard, index is None - if shards_index == None: - return - - save_index_file = next(iter(shards_index)) - index = shards_index[save_index_file] - save_index_file = os.path.join(checkpoint_path, save_index_file) - with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) - logging.info( - f"The model is going to be split in {len(shards)} checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." 
- ) \ No newline at end of file + save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME + save_index_file = add_variant(save_index_file, variant) + + return weights_name, save_index_file + +def get_shard_filename(weights_name: str, idx: int): + """ + get shard file name + """ + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") + shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") + return shard_file \ No newline at end of file diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 376790a59a11..03c09f0dc13a 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -3,6 +3,7 @@ import torch from torch.optim import Adam from torchvision.models import resnet18 +import pathlib from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO @@ -136,6 +137,7 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): with ColoInitContext(device=get_current_device()): model = model_builder() new_model = model_builder() + temp_model = model_builder() config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) chunk_manager = ChunkManager(config_dict) @@ -163,6 +165,12 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): new_model.to('cpu') new_model_dict = new_model.state_dict(only_rank_0=True) recursive_check(model_dict, new_model_dict) + + temp_path = pathlib.Path("/home/lcjmy/code/ColossalAI/tests/test_checkpoint_io/a/test") + pathlib.Path(temp_path).mkdir(parents=True, exist_ok=True) + ckpt_io.load_model(temp_model, model_ckpt_dir.name, strict=True) + pathlib.Path("/home/lcjmy/code/ColossalAI/tests/test_checkpoint_io/a/test/model_state_dict.txt").write_text(str(temp_model.state_dict())) + model_ckpt_dir.cleanup() From a310915dd9e72c00801569de59f9bce1316b0cba Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Mon, 24 Apr 2023 18:03:35 +0800 Subject: [PATCH 16/28] [API Refactoring]gemini plugin support shard checkpoint --- colossalai/booster/plugin/gemini_plugin.py | 17 ----------------- .../checkpoint_io/general_checkpoint_io.py | 19 ------------------- 2 files changed, 36 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 4a9937a324dc..59b34e4762ac 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -80,8 +80,6 @@ def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dten state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=False) weights_name, save_index_file = get_base_filenames(variant, use_safetensors) total_size = 0 - single_shard = None - single_shard_file = None index_file = CheckpointIndexFile(checkpoint_path) for idx, shard_pair in enumerate(state_dict_shard): if not self.coordinator.is_master(): @@ -91,24 +89,9 @@ def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dten total_size = total_size + shard_pair[1] for key in shard.keys(): index_file.append_weight_map(key, shard_file) - if idx == 0: - single_shard = shard - single_shard_file = get_shard_filename(weights_name, idx) - continue - if idx == 1: - checkpoint_file_path = os.path.join(checkpoint_path, single_shard_file) - 
save_state_dict(single_shard, checkpoint_file_path, use_safetensors)
-                single_shard = None
-                single_shard_file = None
-
             total_size = total_size + shard_pair[1]
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
-
-        if single_shard is not None:
-            checkpoint_file_path = os.path.join(checkpoint_path, weights_name)
-            save_state_dict(single_shard, checkpoint_file_path, use_safetensors)
-            return
 
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index dd008fbeb90e..fe1cf81ffa90 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -89,8 +89,6 @@ def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dten
 
         weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
         total_size = 0
-        single_shard = None
-        single_shard_file = None
         index_file = CheckpointIndexFile(checkpoint_path)
         for idx, shard_pair in enumerate(state_dict_shard):
             shard = shard_pair[0]
@@ -98,26 +96,9 @@ def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dten
             total_size = total_size + shard_pair[1]
             for key in shard.keys():
                 index_file.append_weight_map(key, shard_file)
-            if idx == 0:
-                single_shard = shard
-                single_shard_file = get_shard_filename(weights_name, idx)
-                continue
-            if idx == 1:
-                print(single_shard)
-                print("single_shard_file", single_shard_file)
-                checkpoint_file_path = os.path.join(checkpoint_path, single_shard_file)
-                save_state_dict(single_shard, checkpoint_file_path, use_safetensors)
-                single_shard = None
-                single_shard_file = None
-
             total_size = total_size + shard_pair[1]
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)
-
-        if single_shard is not None:
-            checkpoint_file_path = os.path.join(checkpoint_path, weights_name)
-            save_state_dict(single_shard, checkpoint_file_path, use_safetensors)
-            return
 
         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
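
At this point in the series, the save path has settled into the standard sharded-checkpoint layout: one weights file per shard, plus a JSON index whose weight_map sends every parameter name to the shard file holding it and whose metadata records the accumulated byte size. A minimal, self-contained sketch of that layout, using plain torch.save in place of ColossalAI's save_state_dict helper (save_sharded and the file names are illustrative, not part of these patches):

import json
import os

import torch


def save_sharded(state_dict_shard, checkpoint_path, weights_name="pytorch_model.bin"):
    # state_dict_shard yields (shard_dict, shard_byte_size) pairs, mirroring the
    # iterator produced by shard_checkpoint / state_dict_shard in the patches above.
    os.makedirs(checkpoint_path, exist_ok=True)
    weight_map, total_size = {}, 0
    for idx, (shard, shard_size) in enumerate(state_dict_shard):
        shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}.bin")
        for key in shard.keys():
            weight_map[key] = shard_file    # index entry: parameter name -> shard file
        total_size += shard_size
        torch.save(shard, os.path.join(checkpoint_path, shard_file))
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    with open(os.path.join(checkpoint_path, weights_name + ".index.json"), "w", encoding="utf-8") as f:
        f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")

Note that the index is written unconditionally here; the patch above drops the earlier single-shard special case for the same reason: a uniform layout is simpler than emitting one unindexed file whenever the model happens to fit in a single shard.
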
From 5f863eff5ea321d15d2de239942f204bd39621b2 Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Mon, 24 Apr 2023 19:37:02 +0800
Subject: [PATCH 17/28] [API Refactoring]gemini plugin support shard checkpoint

---
 colossalai/booster/plugin/gemini_plugin.py    | 12 ++++++---
 .../checkpoint_io/general_checkpoint_io.py    | 24 ++++++++++++------
 colossalai/checkpoint_io/utils.py             | 25 +++++++------------
 .../test_general_checkpoint_io.py             |  2 +-
 4 files changed, 35 insertions(+), 28 deletions(-)

diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 59b34e4762ac..b1178d4aa475 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -5,6 +5,7 @@
 import os
 import json
 import logging
+import gc
 
 import numpy as np
 import torch
@@ -26,7 +27,9 @@
 
 from colossalai.checkpoint_io.utils import (
     get_base_filenames,
-    get_shard_filename
+    get_shard_filename,
+    is_safetensors_available,
+    load_shard_state_dict
 )
 from colossalai.checkpoint_io import CheckpointIndexFile
 
@@ -102,8 +105,11 @@
         )
 
 
-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
-        return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors)
+    def load_sharded_model(self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+        """
+        Load a sharded model: the model is reassembled from multiple checkpoint files.
+        """
+        return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
 
 class GeminiModel(ModelWrapper):
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index fe1cf81ffa90..e7353e5a4846 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -109,7 +109,8 @@
 
 
-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
+                        use_safetensors: bool = False, load_sub_module: bool = True):
         """
         load shard model, load model from multiple files
         """
@@ -123,19 +124,26 @@
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
-        missing_keys = ckpt_index_file.get_all_param_names()
+        missing_keys = []
 
         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
-            load_state_dict_into_model(model, state_dict, missing_keys, strict)
+            load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
             del state_dict
             gc.collect()
 
-        if strict and len(missing_keys) > 0:
-            error_msgs = 'Missing key(s) in state_dict: {}. '.format(
-                ', '.join('"{}"'.format(k) for k in missing_keys))
-            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                self.__class__.__name__, "\n\t".join(error_msgs)))
+        if strict:
+            remain_keys = set()
+            for i, sub_missing_keys in enumerate(missing_keys):
+                if i == 0:
+                    remain_keys = set(sub_missing_keys)
+                else:
+                    remain_keys = remain_keys & set(sub_missing_keys)
+            if len(remain_keys) > 0:
+                error_msgs = 'Missing key(s) in state_dict: {}. '.format(
+                    ', '.join('"{}"'.format(k) for k in missing_keys))
+                raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                    self.__class__.__name__, "\n\t".join(error_msgs)))
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 0a08587fc74f..c59d8627abde 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -129,7 +129,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
     else:
         return torch.load(checkpoint_file)
 
-def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
+def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
     r"""Copies parameters and buffers from :attr:`state_dict` into
     this module and its descendants.
@@ -150,29 +150,22 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
     if metadata is not None:
         state_dict._metadata = metadata
 
-    def load(module: nn.Module, state_dict, prefix=""):
+    def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
 
         # Parameters of module and children will start with prefix. We can exit early if there are none in this
         # state_dict
         if len([key for key in state_dict if key.startswith(prefix)]) > 0:
             module._load_from_state_dict(*args)
+        if load_sub_module:
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, state_dict, prefix + name + ".")
 
-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, state_dict, prefix + name + ".")
-
-    load(model, state_dict, "")
+    load(model, state_dict, "", load_sub_module)
     del load
 
-    # deal with missing key
-    if len(missing_keys) > 0:
-        deleted_keys = []
-        for key in missing_keys:
-            if key not in sub_missing_keys:
-                deleted_keys.append(key)
-        for key in deleted_keys:
-            missing_keys.remove(key)
+    missing_keys.append(sub_missing_keys)
 
     if strict:
         if len(unexpected_keys) > 0:
diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py
index 03c09f0dc13a..bf285f7957de 100644
--- a/tests/test_checkpoint_io/test_general_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py
@@ -143,13 +143,13 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
     chunk_manager = ChunkManager(config_dict)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
     model = ZeroDDP(model, gemini_manager)
+    model.train()
 
     new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
     new_chunk_manager = ChunkManager(new_config_dict)
     new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
     new_model = ZeroDDP(new_model, new_gemini_manager)
 
-    model.train()
     model_ckpt_dir = tempfile.TemporaryDirectory()
     ckpt_io = GeminiCheckpointIO()
 
From 29184cfe1e4f4b480c76271d107764fa27eba53e Mon Sep 17 00:00:00 2001
From: Mingyan Jiang <1829166702@qq.com>
Date: Mon, 24 Apr 2023 19:40:51 +0800
Subject: [PATCH 18/28] [API Refactoring]gemini plugin support shard checkpoint

---
 tests/test_checkpoint_io/test_general_checkpoint_io.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py
index bf285f7957de..4330e8c8c63c 100644
--- a/tests/test_checkpoint_io/test_general_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py
@@ -137,7 +137,6 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
     with ColoInitContext(device=get_current_device()):
         model = model_builder()
         new_model = model_builder()
-        temp_model = model_builder()
 
     config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
     chunk_manager = ChunkManager(config_dict)
@@ -165,11 +164,6 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
     new_model.to('cpu')
     new_model_dict = new_model.state_dict(only_rank_0=True)
     recursive_check(model_dict, new_model_dict)
-
-    temp_path =
pathlib.Path("/home/lcjmy/code/ColossalAI/tests/test_checkpoint_io/a/test") - pathlib.Path(temp_path).mkdir(parents=True, exist_ok=True) - ckpt_io.load_model(temp_model, model_ckpt_dir.name, strict=True) - pathlib.Path("/home/lcjmy/code/ColossalAI/tests/test_checkpoint_io/a/test/model_state_dict.txt").write_text(str(temp_model.state_dict())) model_ckpt_dir.cleanup() From f3f1dca4b4724a84caae26980918115f1dcc13ed Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Mon, 24 Apr 2023 19:52:41 +0800 Subject: [PATCH 19/28] [API Refactoring]gemini plugin support shard checkpoint --- tests/test_checkpoint_io/test_general_checkpoint_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 4330e8c8c63c..0ae7c4b9e589 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -175,7 +175,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 4]) +@pytest.mark.parametrize('world_size', [4, 4]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size) From 617756d4b73da3ab589380e603f7b51bdfc6e87e Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Mon, 24 Apr 2023 19:54:34 +0800 Subject: [PATCH 20/28] [API Refactoring]gemini plugin support shard checkpoint --- colossalai/booster/plugin/gemini_plugin.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index b1178d4aa475..274e8ba57fbc 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -3,9 +3,7 @@ from typing import Callable, List, Optional, Tuple, Union from pathlib import Path import os -import json import logging -import gc import numpy as np import torch From 5cecee664ff62a1f0837986a77787d44ced27a2a Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 26 Apr 2023 13:52:14 +0800 Subject: [PATCH 21/28] [API Refactoring]gemini plugin support shard checkpoint --- .../checkpoint_io/general_checkpoint_io.py | 16 +++++++++------- .../test_general_checkpoint_io.py | 2 -- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index e7353e5a4846..9c6a592cef8a 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -133,13 +133,15 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri gc.collect() if strict: - remain_keys = set() - for i, sub_missing_keys in enumerate(missing_keys): - if i == 0: - remain_keys = set(sub_missing_keys) - else: - remain_keys = remain_keys & set(sub_missing_keys) - if len(remain_keys) > 0: + def intersection(ll): + if len(ll) == 1: + return set(ll[0]) + if len(ll) == 2: + return set(ll[0]) & set(ll[1]) + n = len(ll) + return intersection(ll[:n//2]) & intersection(ll[n//2:]) + remian_keys = intersection(missing_keys) + if len(remian_keys) > 0: error_msgs = 'Missing key(s) in state_dict: {}. 
'.format( ', '.join('"{}"'.format(k) for k in missing_keys)) raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 0ae7c4b9e589..df250963433a 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -159,9 +159,7 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): # load model if ckpt_io.coordinator.is_master(): ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True) - model.to('cpu') model_dict = model.state_dict(only_rank_0=True) - new_model.to('cpu') new_model_dict = new_model.state_dict(only_rank_0=True) recursive_check(model_dict, new_model_dict) From 4d59978d22a9ff168f22f06a0d10f57f7eb8b810 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Wed, 26 Apr 2023 14:03:21 +0800 Subject: [PATCH 22/28] [API Refactoring]gemini plugin support shard checkpoint --- colossalai/booster/plugin/gemini_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 274e8ba57fbc..f307660f5dc9 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -78,7 +78,7 @@ def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dten """ Save sharded model """ - state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=False) + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True) weights_name, save_index_file = get_base_filenames(variant, use_safetensors) total_size = 0 index_file = CheckpointIndexFile(checkpoint_path) From 7aaa096be85c5a3301d14c57c72776823cbeff86 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 4 May 2023 18:07:28 +0800 Subject: [PATCH 23/28] [API Refactoring]gemini plugin support shard checkpoint --- colossalai/booster/plugin/gemini_plugin.py | 2 +- .../checkpoint_io/general_checkpoint_io.py | 10 +--- colossalai/zero/gemini/gemini_ddp.py | 32 ++++++++++- .../language/gpt/gemini/train_gpt_demo.py | 3 + .../test_general_checkpoint_io.py | 55 ++++++++++++++----- 5 files changed, 76 insertions(+), 26 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index f307660f5dc9..503b8509cedc 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -78,7 +78,7 @@ def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dten """ Save sharded model """ - state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True) + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) weights_name, save_index_file = get_base_filenames(variant, use_safetensors) total_size = 0 index_file = CheckpointIndexFile(checkpoint_path) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 9c6a592cef8a..1285a7c12b2b 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -1,4 +1,5 @@ from pathlib import Path +from functools import reduce import torch.nn as nn from torch.optim import Optimizer @@ -133,14 +134,7 @@ def load_sharded_model(self, model: 
nn.Module, checkpoint_index_file: Path, stri gc.collect() if strict: - def intersection(ll): - if len(ll) == 1: - return set(ll[0]) - if len(ll) == 2: - return set(ll[0]) & set(ll[1]) - n = len(ll) - return intersection(ll[:n//2]) & intersection(ll[n//2:]) - remian_keys = intersection(missing_keys) + remian_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) if len(remian_keys) > 0: error_msgs = 'Missing key(s) in state_dict: {}. '.format( ', '.join('"{}"'.format(k) for k in missing_keys)) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 1acef4f64ddf..1f14e18baf7b 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -1,7 +1,7 @@ import itertools from collections import OrderedDict from functools import partial -from typing import Dict, Iterator, List, Optional, Union, Tuple +from typing import Dict, Iterator, List, Optional, Union, Tuple, Set import torch import torch.distributed as dist @@ -21,7 +21,7 @@ from .gemini_hook import GeminiZeROHook from .gemini_mgr import GeminiManager from .memory_tracer import MemStats, OrderedParamGenerator -from .utils import get_temp_total_chunk_on_cuda +from .utils import get_temp_total_chunk_on_cuda, get_static_torch_model try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys @@ -92,8 +92,36 @@ def __init__(self, param_name = m_name + '.' + p_name if m_name else p_name self.name2param[param_name] = p_var super().__init__(module, process_group=ColoProcessGroup()) + self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module) self._cast_buffers() + def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True): + + r""" + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not + """ + + if memo is None: + memo = set() + if module not in memo: + if remove_duplicate: + memo.add(module) + self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set)) + # non_persistent_buffers_set.extend(sub_set) + # set.union(non_persistent_buffers_set, self_non_persistent_set) + for name, sub_module in module._modules.items(): + if sub_module is None: + continue + submodule_prefix = prefix + ('.' if prefix else '') + name + child_self_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate) + self_non_persistent_set = set.union(self_non_persistent_set, child_self_non_persistent_set) + return self_non_persistent_set + + def _post_forward(self): """This function is only triggered for inference. 
""" diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index b2a7fa36d021..656c9b41a660 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -9,6 +9,7 @@ from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp from packaging import version from torch.nn.parallel import DistributedDataParallel as DDP +from colossalai.zero.gemini.utils import get_static_torch_model import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger @@ -346,6 +347,8 @@ def train_step(): median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") torch.cuda.synchronize() + model_to_save = get_static_torch_model(model, dtype=torch.half, only_rank_0=True) + model_to_save.model.save_pretrained('./tmp') if __name__ == '__main__': diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index df250963433a..14ce2fa8f92d 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -109,22 +109,43 @@ def test_sharded_checkpoint(use_safetensors: bool): recursive_check(model.state_dict(), new_model.state_dict()) recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('model_name', ['bert']) +@parameterize('use_safetensors', [True, False]) +def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: bool): + from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification -@pytest.mark.parametrize('use_safetensors', [True, False]) -def test_hf_load_colossalai_checkpoint(use_safetensors: bool): - from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig - - ckpt_io = GeneralCheckpointIO() - bert_model = BertModel.from_pretrained('bert-base-chinese') model_ckpt_dir = tempfile.TemporaryDirectory() - bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name) - ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False) + # class model_ckpt_dir: + # name = pathlib.Path("/home/lcjmy/code/ColossalAI/tests/test_checkpoint_io/b") + # if use_safetensors == True: + # name = pathlib.Path("/home/lcjmy/code/ColossalAI/tests/test_checkpoint_io/b") + # name.mkdir(parents=True, exist_ok=True) + # @classmethod + # def cleanup(cls): + # pass - new_bert_config = bert_model.config - new_bert_model = BertModel(config=new_bert_config) - new_bert_model = BertModel.from_pretrained(model_ckpt_dir.name) + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, *_ = get_components_func() + + with ColoInitContext(device=get_current_device()): + bert_model = model_builder() + bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name) + config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100) + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + bert_model = ZeroDDP(bert_model, gemini_manager) + bert_model.train() - recursive_check(bert_model.state_dict(), new_bert_model.state_dict()) + ckpt_io = GeminiCheckpointIO() + if ckpt_io.coordinator.is_master(): + model_size = sum(p.numel() * p.element_size() for p in 
bert_model.parameters()) / 1024**2 + ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=False) + new_bert_model = BertForSequenceClassification.from_pretrained(pathlib.Path(model_ckpt_dir.name)) + recursive_check(bert_model.state_dict(only_rank_0=True, dtype=torch.float32), new_bert_model.state_dict()) + + model_ckpt_dir.cleanup() + @parameterize('placement_policy', ['cuda', 'cpu']) @@ -149,10 +170,9 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager) new_model = ZeroDDP(new_model, new_gemini_manager) - model_ckpt_dir = tempfile.TemporaryDirectory() - ckpt_io = GeminiCheckpointIO() + ckpt_io = GeminiCheckpointIO() model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "epoch", (model_size / 3), use_safetensors=use_safetensors) @@ -169,7 +189,8 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - exam_state_dict() + # exam_state_dict() + hf_load_colossalai_checkpoint() @pytest.mark.dist @@ -191,10 +212,14 @@ def recursive_check(d1, d2): elif isinstance(v, list): for i in range(len(v)): if isinstance(v[i], torch.Tensor): + v[i] = v[i].to("cpu") + d2[k][i] = d2[k][i].to("cpu") assert torch.equal(v[i], d2[k][i]) else: assert v[i] == d2[k][i] elif isinstance(v, torch.Tensor): + v = v.to("cpu") + d2[k] = d2[k].to("cpu") assert torch.equal(v, d2[k]) else: assert v == d2[k] From 1cabb7bc3bfe4d3453f9c5af6b6ca7741320f3ba Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 4 May 2023 18:22:10 +0800 Subject: [PATCH 24/28] [API Refactoring]gemini plugin support shard checkpoint --- colossalai/zero/gemini/gemini_ddp.py | 4 +--- examples/language/gpt/gemini/train_gpt_demo.py | 2 -- .../test_general_checkpoint_io.py | 16 +++------------- 3 files changed, 4 insertions(+), 18 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 1f14e18baf7b..f254dffc2e25 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -21,7 +21,7 @@ from .gemini_hook import GeminiZeROHook from .gemini_mgr import GeminiManager from .memory_tracer import MemStats, OrderedParamGenerator -from .utils import get_temp_total_chunk_on_cuda, get_static_torch_model +from .utils import get_temp_total_chunk_on_cuda try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys @@ -111,8 +111,6 @@ def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] if remove_duplicate: memo.add(module) self_non_persistent_set = set(map(lambda key: prefix + ('.' 
if prefix else '') + key, module._non_persistent_buffers_set)) - # non_persistent_buffers_set.extend(sub_set) - # set.union(non_persistent_buffers_set, self_non_persistent_set) for name, sub_module in module._modules.items(): if sub_module is None: continue diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 656c9b41a660..afe0920b7a68 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -347,8 +347,6 @@ def train_step(): median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}") torch.cuda.synchronize() - model_to_save = get_static_torch_model(model, dtype=torch.half, only_rank_0=True) - model_to_save.model.save_pretrained('./tmp') if __name__ == '__main__': diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 14ce2fa8f92d..752ca706bfd4 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -3,7 +3,6 @@ import torch from torch.optim import Adam from torchvision.models import resnet18 -import pathlib from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO @@ -116,15 +115,6 @@ def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification model_ckpt_dir = tempfile.TemporaryDirectory() - # class model_ckpt_dir: - # name = pathlib.Path("/home/lcjmy/code/ColossalAI/tests/test_checkpoint_io/b") - # if use_safetensors == True: - # name = pathlib.Path("/home/lcjmy/code/ColossalAI/tests/test_checkpoint_io/b") - # name.mkdir(parents=True, exist_ok=True) - # @classmethod - # def cleanup(cls): - # pass - get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, *_ = get_components_func() @@ -140,8 +130,8 @@ def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: ckpt_io = GeminiCheckpointIO() if ckpt_io.coordinator.is_master(): model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 - ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=False) - new_bert_model = BertForSequenceClassification.from_pretrained(pathlib.Path(model_ckpt_dir.name)) + ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=use_safetensors) + new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name) recursive_check(bert_model.state_dict(only_rank_0=True, dtype=torch.float32), new_bert_model.state_dict()) model_ckpt_dir.cleanup() @@ -189,7 +179,7 @@ def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - # exam_state_dict() + exam_state_dict() hf_load_colossalai_checkpoint() From 3cdbda5f7373d94c3d3dbfdeb6adfd5c9af101cf Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Thu, 4 May 2023 18:33:58 +0800 Subject: [PATCH 25/28] [API Refactoring]gemini plugin support shard checkpoint --- colossalai/booster/plugin/gemini_plugin.py | 4 +--- 
colossalai/checkpoint_io/checkpoint_io_base.py | 1 + colossalai/checkpoint_io/general_checkpoint_io.py | 2 -- colossalai/checkpoint_io/index_file.py | 1 - colossalai/checkpoint_io/utils.py | 3 --- 5 files changed, 2 insertions(+), 9 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 503b8509cedc..dfdd7be26eaa 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -25,9 +25,7 @@ from colossalai.checkpoint_io.utils import ( get_base_filenames, - get_shard_filename, - is_safetensors_available, - load_shard_state_dict + get_shard_filename ) from colossalai.checkpoint_io import CheckpointIndexFile diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index e5bfcc68ac71..e5dadab1e56e 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -86,6 +86,7 @@ def load_model(self, # the existence of index file means it is a sharded checkpoint ckpt_path = Path(checkpoint) index_file_exists, index_file_path = has_index_file(checkpoint) + # return the origin model instead of the unwrapped model origin_model = model diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 1285a7c12b2b..f2a98cd9d838 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -5,10 +5,8 @@ from torch.optim import Optimizer import logging import os -import json import gc from typing import Optional, Iterator, OrderedDict, Tuple -import itertools from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index b2fb03ba3016..15a6d09f3b5e 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -3,7 +3,6 @@ from typing import Any, List, Union import os import json -import logging from .utils import is_dtensor_checkpoint diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index c59d8627abde..16e41631f0d5 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -5,9 +5,6 @@ from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator from colossalai.tensor.d_tensor.d_tensor import DTensor import re -import os -import json -import logging SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" From 9be4e9c254e0d346c0917bb252a8b6ab382b07ee Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Fri, 5 May 2023 11:01:21 +0800 Subject: [PATCH 26/28] [API Refactoring]gemini plugin support shard checkpoint --- colossalai/checkpoint_io/general_checkpoint_io.py | 4 ++-- examples/language/gpt/gemini/train_gpt_demo.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index f2a98cd9d838..96a883fdb42a 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -132,8 +132,8 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri gc.collect() if strict: - remian_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) - if len(remian_keys) > 0: + remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) 
+ if len(remain_keys) > 0: error_msgs = 'Missing key(s) in state_dict: {}. '.format( ', '.join('"{}"'.format(k) for k in missing_keys)) raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index afe0920b7a68..b2a7fa36d021 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -9,7 +9,6 @@ from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp from packaging import version from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.zero.gemini.utils import get_static_torch_model import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger From fca924c930e49db0f089abd4131a8b2db1dbd4b8 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Fri, 5 May 2023 11:43:28 +0800 Subject: [PATCH 27/28] [API Refactoring]gemini plugin support shard checkpoint --- colossalai/zero/gemini/gemini_ddp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index f254dffc2e25..361dd9bd4d23 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -107,6 +107,7 @@ def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] if memo is None: memo = set() + self_non_persistent_set = set() if module not in memo: if remove_duplicate: memo.add(module) From 6d8270e8eb8712a57d0aabced02a6841c46a15f8 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Fri, 5 May 2023 11:45:22 +0800 Subject: [PATCH 28/28] [API Refactoring]gemini plugin support shard checkpoint --- colossalai/zero/gemini/gemini_ddp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 361dd9bd4d23..3802568151d3 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -116,8 +116,8 @@ def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] if sub_module is None: continue submodule_prefix = prefix + ('.' if prefix else '') + name - child_self_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate) - self_non_persistent_set = set.union(self_non_persistent_set, child_self_non_persistent_set) + child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate) + self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set) return self_non_persistent_set
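
With the last patch applied, the load path reads the index file, streams one shard at a time (deleting each state dict before loading the next), and treats a key as missing only when every shard failed to supply it, which is what the reduce-based set intersection computes. A minimal sketch of the same idea against a plain nn.Module, reading the index format shown earlier (load_from_index and its strict handling are illustrative, not the ColossalAI API):

import json
from functools import reduce
from pathlib import Path

import torch
import torch.nn as nn


def load_from_index(model: nn.Module, index_file_path: str, strict: bool = True):
    index_path = Path(index_file_path)
    index = json.loads(index_path.read_text(encoding="utf-8"))
    missing_per_shard = []
    # Load shard by shard so peak host memory stays near one shard, not the whole model.
    for shard_file in sorted(set(index["weight_map"].values())):
        state_dict = torch.load(index_path.parent / shard_file, map_location="cpu")
        result = model.load_state_dict(state_dict, strict=False)
        missing_per_shard.append(result.missing_keys)
        del state_dict
    if strict and missing_per_shard:
        # A key is genuinely missing only if no shard supplied it.
        remain_keys = reduce(lambda a, b: a & b, map(set, missing_per_shard))
        if remain_keys:
            raise RuntimeError(f"Missing key(s) in state_dict: {sorted(remain_keys)}")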
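
The other substantive change that lands at the end of the series is _get_non_persistent_buffers_set, which recursively collects the fully qualified names of non-persistent buffers so Gemini's state-dict handling can skip them. A simplified sketch of that traversal, omitting the memo set the real method keeps to avoid revisiting shared submodules (non_persistent_buffer_names and Demo are illustrative names):

from typing import Set

import torch
import torch.nn as nn


def non_persistent_buffer_names(module: nn.Module, prefix: str = "") -> Set[str]:
    # Qualify this module's own non-persistent buffers with the current prefix.
    names = {prefix + ("." if prefix else "") + name
             for name in module._non_persistent_buffers_set}
    # Recurse into children, extending the prefix the same way the patched method does.
    for child_name, child in module._modules.items():
        if child is not None:
            child_prefix = prefix + ("." if prefix else "") + child_name
            names |= non_persistent_buffer_names(child, child_prefix)
    return names


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("mask", torch.ones(4), persistent=False)
        self.bn = nn.BatchNorm1d(4)    # its running stats are persistent buffers


# non_persistent_buffer_names(Demo()) == {"mask"}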