From 0a3bb550fa076ebceef7a1e4f58bd161f7d491e4 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 25 Aug 2023 20:48:08 +0800 Subject: [PATCH 01/15] implement sharded optimizer saving --- .../booster/plugin/hybrid_parallel_plugin.py | 29 ++- .../hybrid_parallel_checkpoint_io.py | 180 +++++++++++++++++- ...st_hybrid_parallel_plugin_checkpoint_io.py | 2 + 3 files changed, 202 insertions(+), 9 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index c49b3e1823cd..4228296ce3ba 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -110,6 +110,30 @@ def unwrap(self): return module +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of optimizer for future use, which includes: + # 1. A complete param_group, with params in the form of param_id + # 2. A mapping from param address to param_id + # 3. A mapping from param_id to param address + + param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}} + start_index = 0 + for group in optim.param_groups: + + packed_group = {k: v for k, v in group.items() if k != 'params'} + packed_group['params'] = [] + + for param_id, param in enumerate(group['params'], start_index): + packed_group['params'].append(param_id) + param_info['param2id'][id(param)] = param_id + param_info['id2param'][param_id] = id(param) + + param_info['param_groups'].append(packed_group) + start_index += len(group['params']) + + return param_info + + def init_pipeline_optimizer(optim: Optimizer, model: Module): params = set(model.parameters()) new_param_groups = [] @@ -122,6 +146,7 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): class HybridParallelNaiveOptimizer(OptimizerWrapper): def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool): + self.param_info = get_param_info(optim) if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim) @@ -142,6 +167,7 @@ def __init__(self, hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0): + self.param_info = get_param_info(optim) if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -172,6 +198,7 @@ def __init__( dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None): + self.param_info = get_param_info(optimizer) if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -461,7 +488,7 @@ def seed_worker(worker_id): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group) + return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 56a89bff75ca..6b608625b183 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -14,12 +14,14 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.cluster import ProcessGroupMesh +from 
colossalai.interface import OptimizerWrapper from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, to_global, to_global_for_customized_distributed_tensor, ) +from colossalai.zero.low_level import LowLevelZeroOptimizer from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -52,9 +54,10 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): dp_group (ProcessGroup): Process group along data parallel dimension. pp_group (ProcessGroup): Process group along pipeline parallel dimension. tp_group (ProcessGroup): Process group along tensor parallel dimension. + zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. """ - def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup) -> None: + def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup, zero_stage: int) -> None: super().__init__() self.dp_group = dp_group self.pp_group = pp_group @@ -65,6 +68,7 @@ def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: Pro self.dp_size = dist.get_world_size(dp_group) self.pp_size = dist.get_world_size(pp_group) self.tp_size = dist.get_world_size(tp_group) + self.zero_stage = zero_stage @staticmethod def _model_sharder(model: nn.Module, @@ -106,10 +110,54 @@ def _model_sharder(model: nn.Module, yield state_dict_sharder.current_block, state_dict_sharder.current_block_size @staticmethod - def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024): + def _optimizer_sharder(optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + size_per_shard: int = 1024): + # An internel method that breaks state_dict of optimizer into shards within limited size. - # TODO (Baizhou): Implement sharding feature of optimizer. - pass + + state_dict_sharder = StateDictSharder(size_per_shard) + dp_size = dist.get_world_size(dp_group) + + for param, state in optimizer.optim.state.items(): + + if param is None: + continue + + state_ = copy.deepcopy(state) + + # First handle Zero shards. Working params in fp16 should first be obtained. + if use_zero: + working_param = optimizer._param_store.master_to_working_param[id(param)] + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != 'step': + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( + working_param) + state_[k] = param_state.detach().cpu() + param_id = optimizer.param_info['param2id'][id(working_param)] + else: + param_id = optimizer.param_info['param2id'][id(param)] + + # Then hanlde TP Shards + use_tp = is_distributed_tensor(param) or is_customized_distributed_tensor(param) + if use_tp: + # TODO: solve the tensor parallel case + # Either collect shape information from dtensor.ShardSpec + # Or collect shape information before sharder.optimize + # I think the second way is more reasonable + pass + + state_ = {k: v.detach().cpu() for k, v in state_.items()} + block, block_size = state_dict_sharder.append(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size def save_sharded_model(self, model: nn.Module, @@ -148,7 +196,7 @@ def save_sharded_model(self, return # Then collect the sharded parameters & buffers along tp_group. 
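        # As a rough example, with dp_size = tp_size = pp_size = 2 the division of labour is:
        #   dp_rank 1, any tp/pp rank   -> returns early (its weights are replicas of dp_rank 0)
        #   dp_rank 0, tp_rank 1        -> participates in the tp gather but writes nothing
        #   dp_rank 0, tp_rank 0        -> writes the shard files of its own pipeline stage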
- # Only devices with tp_size == 0 are responsible for model saving. + # Only devices with tp_rank == 0 are responsible for model saving. state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) @@ -282,14 +330,130 @@ def _load(name: str): _load(extra_state_key) def save_sharded_optimizer(self, - optimizer: Optimizer, + optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024): - pass + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Perfix of file to save + size_per_shard (int): Max file size of each file shard that store state tensors + """ + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the model. + if self.zero_stage == 0 and self.dp_rank != 0: + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(optimizer, + use_zero=(self.zero_stage > 0), + dp_group=self.dp_group, + size_per_shard=size_per_shard) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = (self.dp_rank == 0 and self.tp_rank == 0) + + if self.pp_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving) + + if control_saving: + # Store param groups. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + # Store index file. + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info(f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + else: + # When pipeline is used, each stage produces its own shard files and index files. 
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving) + + if control_saving: + assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for states, states_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(states, states_filename) + + # Store param groups. + final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + + rmtree(tmp_index_file_folder) + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") + + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + index_file_path (str): Path to the index file of checkpointing folder. + prefix (str): Not used. 
+ """ pass def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 67d73c31f6e0..1fef389c58e2 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -18,6 +18,8 @@ ) from tests.kit.model_zoo import model_zoo +# TODO (Baizhou): Add test cases for: shard=False/PP+Zero + @clear_cache_before_run() @parameterize('shard', [True]) From 0cf7d68d2d02f109f2f17a0a397729fc426a1e54 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 25 Aug 2023 22:21:46 +0800 Subject: [PATCH 02/15] add more param info --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 4228296ce3ba..38c478c66f9c 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -111,10 +111,10 @@ def unwrap(self): def get_param_info(optim: Optimizer): - # Get a backup of necessary information of optimizer for future use, which includes: + # Get a backup of necessary information of parameters for future use, which includes: # 1. A complete param_group, with params in the form of param_id # 2. A mapping from param address to param_id - # 3. A mapping from param_id to param address + # 3. A mapping from param_id to param address, as well as the original shape of parameter. param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}} start_index = 0 @@ -124,9 +124,10 @@ def get_param_info(optim: Optimizer): packed_group['params'] = [] for param_id, param in enumerate(group['params'], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None packed_group['params'].append(param_id) param_info['param2id'][id(param)] = param_id - param_info['id2param'][param_id] = id(param) + param_info['id2param'][param_id] = {'id': id(param), 'original_shape': original_shape} param_info['param_groups'].append(packed_group) start_index += len(group['params']) From a896297ac6219aaf7b9eb8982da97e7027a09734 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sun, 27 Aug 2023 21:14:16 +0800 Subject: [PATCH 03/15] finish implementation of sharded optimizer saving --- .../booster/plugin/hybrid_parallel_plugin.py | 10 +- .../hybrid_parallel_checkpoint_io.py | 33 ++--- colossalai/checkpoint_io/utils.py | 131 +++++++++++++++--- 3 files changed, 128 insertions(+), 46 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 38c478c66f9c..cd0fea104b44 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -113,10 +113,11 @@ def unwrap(self): def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: # 1. A complete param_group, with params in the form of param_id - # 2. A mapping from param address to param_id - # 3. A mapping from param_id to param address, as well as the original shape of parameter. + # 2. A mapping from param address (obtained using id(param)) to integer param_id + # 3. A mapping from integer param_id to param address. + # 4. 
A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. - param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}} + param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} start_index = 0 for group in optim.param_groups: @@ -127,7 +128,8 @@ def get_param_info(optim: Optimizer): original_shape = param.shape if isinstance(param, torch.Tensor) else None packed_group['params'].append(param_id) param_info['param2id'][id(param)] = param_id - param_info['id2param'][param_id] = {'id': id(param), 'original_shape': original_shape} + param_info['id2param'][param_id] = id(param) + param_info['param2shape'][id(param)] = original_shape param_info['param_groups'].append(packed_group) start_index += len(group['params']) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 6b608625b183..d3178520571f 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -13,21 +13,13 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from colossalai.cluster import ProcessGroupMesh from colossalai.interface import OptimizerWrapper -from colossalai.tensor.d_tensor import ( - is_customized_distributed_tensor, - is_distributed_tensor, - to_global, - to_global_for_customized_distributed_tensor, -) -from colossalai.zero.low_level import LowLevelZeroOptimizer from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile from .utils import ( StateDictSharder, - calculate_tensor_size, + gather_distributed_optimizer_states, gather_distributed_param, get_model_base_filenames, get_optimizer_base_filenames, @@ -85,7 +77,7 @@ def _model_sharder(model: nn.Module, continue # Gather tensor pieces when using tensor parallel. param_ = gather_distributed_param(param, keep_vars=False) - block, block_size = state_dict_sharder.append(prefix + name, param_) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) if block is not None: yield block, block_size @@ -93,7 +85,7 @@ def _model_sharder(model: nn.Module, for name, buf in model.named_buffers(): if buf is not None and name not in model._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block, block_size = state_dict_sharder.append(prefix + name, buffer) + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size @@ -102,7 +94,7 @@ def _model_sharder(model: nn.Module, if getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = model.get_extra_state() - block, block_size = state_dict_sharder.append(extra_state_key, extra_state) + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size @@ -113,6 +105,7 @@ def _model_sharder(model: nn.Module, def _optimizer_sharder(optimizer: OptimizerWrapper, use_zero: bool, dp_group: ProcessGroup, + tp_group: ProcessGroup, size_per_shard: int = 1024): # An internel method that breaks state_dict of optimizer into shards within limited size. @@ -129,6 +122,8 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, # First handle Zero shards. Working params in fp16 should first be obtained. 
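            # Roughly, reconstructing one full state tensor from its ZeRO shards looks like the
            # following, where `v` is this rank's flat fp32 slice and `working_param` the fp16 param:
            #
            #     shards = [torch.zeros_like(v) for _ in range(dp_size)]
            #     dist.all_gather(shards, v.cuda(), group=dp_group)
            #     full = torch.stack(shards).view(-1)[:working_param.numel()].reshape_as(working_param)
            #
            # The trailing elements dropped by the slicing are the padding that was appended when
            # ZeRO flattened and partitioned the parameter.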
if use_zero: + # Zero optimizer has replaced params in param_groups with sharded master params in fp32, + # so we have to first target working params in fp16 using stored mapping. working_param = optimizer._param_store.master_to_working_param[id(param)] for k, v in state.items(): if isinstance(v, torch.Tensor) and k != 'step': @@ -138,21 +133,16 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( working_param) state_[k] = param_state.detach().cpu() + param = working_param param_id = optimizer.param_info['param2id'][id(working_param)] else: param_id = optimizer.param_info['param2id'][id(param)] # Then hanlde TP Shards - use_tp = is_distributed_tensor(param) or is_customized_distributed_tensor(param) - if use_tp: - # TODO: solve the tensor parallel case - # Either collect shape information from dtensor.ShardSpec - # Or collect shape information before sharder.optimize - # I think the second way is more reasonable - pass - + original_shape = optimizer.param_info['param2shape'][id(param)] + state_ = gather_distributed_optimizer_states(state_, param, original_shape, tp_group) state_ = {k: v.detach().cpu() for k, v in state_.items()} - block, block_size = state_dict_sharder.append(param_id, state_) + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) if block is not None: yield block, block_size @@ -367,6 +357,7 @@ def save_sharded_optimizer(self, state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(optimizer, use_zero=(self.zero_stage > 0), dp_group=self.dp_group, + tp_group=self.tp_group, size_per_shard=size_per_shard) states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) index_file = CheckpointIndexFile(checkpoint) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index d04159c54d5e..bf8f31b811e9 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,4 +1,5 @@ # coding=utf-8 +import copy import os import re from collections import abc as container_abcs @@ -8,7 +9,9 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple import torch +import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup from torch.optim import Optimizer from colossalai.interface import OptimizerWrapper @@ -93,26 +96,6 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False -def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False): - """ - Gather the complete parameter for saving if passed in param is distributed. - - Args: - param (torch.Tensor): A model parameter, might be d_tensor. - keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. 
- - Returns: - torch.Tensor: the complete parameter - """ - param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): - return to_global(param_) - elif is_customized_distributed_tensor(param_): - return to_global_for_customized_distributed_tensor(param_) - else: - return param_ - - # ====================================== # Helper classes and functions for saving shard file # ====================================== @@ -136,7 +119,8 @@ def __init__(self, size_per_shard: int) -> None: self.current_block = OrderedDict() self.current_block_size = 0 - def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + tensor_size = calculate_tensor_size(tensor) ret_block = None ret_block_size = 0 @@ -153,6 +137,111 @@ def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict] self.current_block_size += tensor_size return ret_block, ret_block_size + def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]: + + state_size = 0 + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state_size += calculate_tensor_size(v) + ret_block = None + ret_block_size = 0 + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[param_id] = state + self.current_block_size += state_size + return ret_block, ret_block_size + + +def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor: + """ + Gather the complete parameter for saving if passed in param is distributed under tp setting. + + Args: + param (torch.Tensor): A model parameter, might be d_tensor. + keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. + + Returns: + torch.Tensor: the complete parameter + """ + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + return to_global(param_) + elif is_customized_distributed_tensor(param_): + return to_global_for_customized_distributed_tensor(param_) + else: + return param_ + + +def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]: + """ + Given the current shape of parameter and the shape of parameter before sharding, + return the dimension along which the parameter is sharded when using tensor parallel. + If tensor parallel is not used, return None. + + Args: + current_shape (torch.Size): The current shape of parameter after sharding. + original_shape (torch.Size): The shape of parameter before sharding. + tp_size (int): The size of tp group. + + Returns: + Optional[int]: The dimension along which parameter is partitioned. 
+ """ + partition_dim = None + for dim, length in enumerate(original_shape): + if length > current_shape[dim]: + partition_dim = dim + break + if partition_dim is not None: + assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \ + f"The parameter isn't evenly distributed among tensor parallel group: \ + shape before sharding {original_shape}, shape after sharding {current_shape}" + + return partition_dim + + +def gather_distributed_optimizer_states(state: OrderedDict, + param: torch.Tensor, + original_shape: torch.Size, + tp_group: ProcessGroup, + inplace: bool = True) -> OrderedDict: + """ + With given parameter and its optimizer states, + gather the complete optimizer state for saving + if the passed in param is distributed under tp setting. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp group if tp is used. + param (torch.Tensor): The given parameter, might be d_tensor. + original_shape (torch.Size): The size of parameter before sharding. + tp_group (ProcessGroup): The process group of tensor parallel. + inplace (bool, optional): If set to True, will update the values of passed in state dict to the gathered states. Defaults to True. + + Returns: + OrderedDict: The complete optimizer state of given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + if is_distributed_tensor(param) or is_customized_distributed_tensor(param): + tp_size = dist.get_world_size(tp_group) + partition_dim = search_tp_partition_dim(param.shape, original_shape, tp_size) + if partition_dim is not None: + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != 'step': + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + param_state = torch.stack(gather_tensor, dim=partition_dim) + state_[k] = param_state.detach().cpu() + if inplace: + del v + return state_ + def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, From 6092b6ae118db478a11bbcc91cc662503f5cae77 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 28 Aug 2023 11:18:13 +0800 Subject: [PATCH 04/15] fix bugs in optimizer sharded saving --- .../booster/plugin/hybrid_parallel_plugin.py | 21 ++++++++++++------- .../hybrid_parallel_checkpoint_io.py | 18 +++++++++------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index cd0fea104b44..9c16826ba478 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,7 +1,7 @@ import random from contextlib import nullcontext from functools import partial -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union import numpy as np import torch @@ -116,7 +116,8 @@ def get_param_info(optim: Optimizer): # 2. A mapping from param address (obtained using id(param)) to integer param_id # 3. A mapping from integer param_id to param address. # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. 
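    # For a toy optimizer over two parameters p0 (shape [4, 2]) and p1 (shape [3]), the structure
    # built below would look roughly like:
    #     {
    #         'param_groups': [{'lr': ..., 'params': [0, 1]}],
    #         'param2id':     {id(p0): 0, id(p1): 1},
    #         'id2param':     {0: id(p0), 1: id(p1)},
    #         'param2shape':  {id(p0): torch.Size([4, 2]), id(p1): torch.Size([3])},
    #     }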
- + if optim is None: + return {} param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} start_index = 0 for group in optim.param_groups: @@ -148,8 +149,8 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool): - self.param_info = get_param_info(optim) + def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim) @@ -161,6 +162,7 @@ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, + param_info: OrderedDict, precision: str = 'fp16', initial_scale: float = 2**16, min_scale: float = 1, @@ -170,7 +172,7 @@ def __init__(self, hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0): - self.param_info = get_param_info(optim) + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -184,6 +186,7 @@ def __init__( optimizer: Optimizer, model: Module, use_pipeline: bool, + param_info: OrderedDict, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2., @@ -201,7 +204,7 @@ def __init__( dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None): - self.param_info = get_param_info(optimizer) + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -386,6 +389,7 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, @@ -396,19 +400,22 @@ def configure( optimizer = HybridParallelAMPOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, precision=self.precision, max_norm=self.max_norm, **self.amp_config) else: optimizer = HybridParallelNaiveOptimizer(optimizer, model, - use_pipeline=self.enable_pipeline_parallelism) + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info) else: assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
optimizer = HybridParallelZeroOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, dp_process_group=self.dp_group, tp_process_group=self.tp_group, verbose=True, diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index d3178520571f..b1edcac39197 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -1,5 +1,4 @@ import copy -import gc import logging import os from pathlib import Path @@ -112,6 +111,7 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, state_dict_sharder = StateDictSharder(size_per_shard) dp_size = dist.get_world_size(dp_group) + param_info = optimizer.param_info for param, state in optimizer.optim.state.items(): @@ -134,12 +134,11 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, working_param) state_[k] = param_state.detach().cpu() param = working_param - param_id = optimizer.param_info['param2id'][id(working_param)] - else: - param_id = optimizer.param_info['param2id'][id(param)] + + param_id = param_info['param2id'][id(param)] # Then hanlde TP Shards - original_shape = optimizer.param_info['param2shape'][id(param)] + original_shape = param_info['param2shape'][id(param)] state_ = gather_distributed_optimizer_states(state_, param, original_shape, tp_group) state_ = {k: v.detach().cpu() for k, v in state_.items()} block, block_size = state_dict_sharder.append_optim_state(param_id, state_) @@ -421,8 +420,13 @@ def save_sharded_optimizer(self, for filename in os.listdir(tmp_index_file_folder): stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] - for states, states_filename in stage_index_file.weight_map.items(): - final_index_file.append_weight_map(states, states_filename) + for param_id, state_filename in stage_index_file.weight_map.items(): + if param_id not in final_index_file.weight_map: + final_index_file.append_weight_map(param_id, state_filename) + else: + # If parameter is shared with other stage, delete this copy of optimizer state. + # For example: only save one copy of optimizer states between embedding weight & lm_head. + os.remove(os.path.join(checkpoint, state_filename)) # Store param groups. 
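                # At this point the checkpoint folder roughly contains, for a 2-stage pipeline:
                #     pytorch_optim.bin.index.json                  (merged index, written below)
                #     pytorch_optim_group.bin                       (param_groups, written below)
                #     pytorch_optim-stage-00000-shard-*.bin         (states saved by stage 0)
                #     pytorch_optim-stage-00001-shard-*.bin         (states saved by stage 1)
                # with the exact shard numbering coming from the filenames generated during saving.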
final_index_file.append_meta_data("param_groups", param_group_file) From 99a21db6c8fc5e6a42daf51497ef421b224748e3 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 28 Aug 2023 11:23:15 +0800 Subject: [PATCH 05/15] add pp+zero test --- ...st_hybrid_parallel_plugin_checkpoint_io.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 1fef389c58e2..051a099f6af1 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -18,37 +18,34 @@ ) from tests.kit.model_zoo import model_zoo -# TODO (Baizhou): Add test cases for: shard=False/PP+Zero - +# TODO (Baizhou): Add test cases for shard=False @clear_cache_before_run() @parameterize('shard', [True]) -@parameterize('model_name', ['transformers_gpt']) +@parameterize('model_name', ['transformers_gpt', 'transformers_bert']) @parameterize('size_per_shard', [32]) @parameterize('test_config', [{ 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4, 'precision': 'fp32', -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp32', }, { 'tp_size': 4, 'pp_size': 1, 'precision': 'fp32', -}, { - 'tp_size': 2, - 'pp_size': 1, - 'precision': 'fp32', }, { 'tp_size': 2, 'pp_size': 1, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): From 79ab97a68dc95659002c8cab29d37b12475a4d1b Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 28 Aug 2023 17:52:30 +0800 Subject: [PATCH 06/15] param group loading --- .../hybrid_parallel_checkpoint_io.py | 36 ++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index b1edcac39197..8c502aa2a936 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -24,6 +24,7 @@ get_optimizer_base_filenames, get_shard_filename, is_safetensors_available, + load_param_groups_into_optimizer, load_shard_state_dict, load_state_dict_into_model, save_param_groups, @@ -259,7 +260,7 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri Args: model (nn.Module): The model to be loaded. - index_file_path (str): Path to the index file of checkpointing folder. + checkpoint_index_file (str): Path to the index file of checkpointing folder. strict (bool, optional): For name matching during loading state_dict. Defaults to False. This argument should be manually set to False since params on same device might be stored in different files. """ @@ -440,16 +441,43 @@ def save_sharded_optimizer(self, f"You can find where each parameters has been saved in the " f"index located at {final_index_file_path}.") - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str = ""): + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): """ Load sharded optimizer with the given path to index file of checkpoint folder. Args: optimizer (OptimizerWrapper): The optimizer to be loaded. 
- index_file_path (str): Path to the index file of checkpointing folder. + checkpoint_index_file (str): Path to the index file of checkpointing folder. prefix (str): Not used. """ - pass + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory.') + saved_groups = torch.load(param_group_path) + + updated_groups = [] + + # A mapping from integer id to parameter object. + # IDs should be obtained through saved param2id mapping in optimizer.param_info. + id_map = {} + + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + + # update id_map + for param in old_pg['params']: + param_id = optimizer.param_info['param2id'][id(param)] + id_map[param_id] = param + + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change. + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({'param_groups': updated_groups}) def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): # TODO(Baizhou): support this feature after implementing complete state_dict collection From 9e71f0ab5ff4151860634c3a20827cf51388b53d Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 29 Aug 2023 12:28:57 +0800 Subject: [PATCH 07/15] greedy loading of optimizer --- .../booster/plugin/hybrid_parallel_plugin.py | 2 + .../hybrid_parallel_checkpoint_io.py | 74 ++++++++++++++----- colossalai/checkpoint_io/utils.py | 9 ++- 3 files changed, 63 insertions(+), 22 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 9c16826ba478..a49ae44c27a2 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -116,6 +116,8 @@ def get_param_info(optim: Optimizer): # 2. A mapping from param address (obtained using id(param)) to integer param_id # 3. A mapping from integer param_id to param address. # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. + # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. 
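    # A short sketch of how this is used later: get_param_info is called in `configure` before the
    # ZeRO wrapper replaces param_groups, so the recorded ids always refer to the model's own
    # (working) parameters. When checkpointing under ZeRO, a master fp32 parameter is first mapped
    # back, e.g.
    #     working = optimizer._param_store.master_to_working_param[id(master_param)]
    #     param_id = param_info['param2id'][id(working)]
    # before any lookup into this table.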
+ if optim is None: return {} param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 8c502aa2a936..b006c088c239 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -1,4 +1,5 @@ import copy +import gc import logging import os from pathlib import Path @@ -22,14 +23,13 @@ gather_distributed_param, get_model_base_filenames, get_optimizer_base_filenames, - get_shard_filename, is_safetensors_available, - load_param_groups_into_optimizer, load_shard_state_dict, load_state_dict_into_model, + load_states_into_optimizer, save_param_groups, - save_state_dict, save_state_dict_shards, + sharded_optimizer_loading_epilogue, ) try: @@ -60,7 +60,7 @@ def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: Pro self.dp_size = dist.get_world_size(dp_group) self.pp_size = dist.get_world_size(pp_group) self.tp_size = dist.get_world_size(tp_group) - self.zero_stage = zero_stage + self.use_zero = (zero_stage > 0) @staticmethod def _model_sharder(model: nn.Module, @@ -142,7 +142,7 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, original_shape = param_info['param2shape'][id(param)] state_ = gather_distributed_optimizer_states(state_, param, original_shape, tp_group) state_ = {k: v.detach().cpu() for k, v in state_.items()} - block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + block, block_size = state_dict_sharder.append_optim_state(str(param_id), state_) if block is not None: yield block, block_size @@ -302,6 +302,7 @@ def _load(name: str): strict=strict, load_sub_module=True) del state_dict + gc.collect() loaded_file.add(filename) # Load parameters. @@ -349,13 +350,13 @@ def save_sharded_optimizer(self, # Devices along the same dp_group share the same copies of states when zero is not used. # In this case only let the device with dp_rank == 0 save the model. - if self.zero_stage == 0 and self.dp_rank != 0: + if not self.use_zero and self.dp_rank != 0: return # Then collect the sharded states along dp_group(if using zero)/tp_group. # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(optimizer, - use_zero=(self.zero_stage > 0), + use_zero=self.use_zero, dp_group=self.dp_group, tp_group=self.tp_group, size_per_shard=size_per_shard) @@ -450,8 +451,28 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_f checkpoint_index_file (str): Path to the index file of checkpointing folder. prefix (str): Not used. """ + + def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): + if use_zero: + working_param = optimizer._param_store.master_to_working_param[id(param)] + param_id = optimizer.param_info['param2id'][id(working_param)] + else: + param_id = optimizer.param_info['param2id'][id(param)] + return param_id + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. + # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. 
+ id_map = {} + for pg in optimizer.optim.param_groups: + for param in pg['params']: + param_id = _get_param_id_from_optimizer_param(param, self.use_zero) + id_map[param_id] = param + # Read checkpoint index file. ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map # Load param_groups param_group_path = ckpt_index_file.get_param_group_filename() @@ -461,24 +482,41 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_f saved_groups = torch.load(param_group_path) updated_groups = [] - - # A mapping from integer id to parameter object. - # IDs should be obtained through saved param2id mapping in optimizer.param_info. - id_map = {} - for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): - - # update id_map - for param in old_pg['params']: - param_id = optimizer.param_info['param2id'][id(param)] - id_map[param_id] = param - # obtain updated param group new_pg = copy.deepcopy(saved_pg) new_pg['params'] = old_pg['params'] # The parameters in the same group shouln't change. updated_groups.append(new_pg) optimizer.optim.__dict__.update({'param_groups': updated_groups}) + # Load saved states to optimizer. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg['params']: + param_id = str(_get_param_id_from_optimizer_param(param, self.use_zero)) + if param_id not in weight_map: + raise ValueError( + f"Parameter with ID:{param_id} is not stored in checkpoint, please check your checkpointing configuration!" + ) + filename = weight_map[param_id] + + # If this param's states has been loaded before, directly return. + if filename in loaded_file: + continue + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + del state_dict + gc.collect() + loaded_file.add(filename) + + # Shard the states along tensor parallel group. + + # Shard the states along data parallel group when Zero is used. + sharded_optimizer_loading_epilogue(optimizer.optim) + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): # TODO(Baizhou): support this feature after implementing complete state_dict collection raise NotImplementedError diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index bf8f31b811e9..d3b36d1803b5 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -472,15 +472,16 @@ def update_group(group, new_group): return id_map -def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict): +def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False): r"""Copies states from `state_dict` into an Optimizer object. Args: optimizer(Optimizer): An initialized Optimizer object to be loaded - state_dict(dict): a mapping from tensor index (an integer) + state_dict(dict): A mapping from tensor index (an integer) to its states to be loaded (a mapping from state name to a tensor). - id_map(dict): a mapping from tensor index (an integer) + id_map(dict): A mapping from tensor index (an integer) to its corresponding parameter (a tensor) whose states will be updated. + strict(bool, optional): If set to True, only load the parameters with its id in id_map. 
Defaults to False. """ def cast(param, value, key=None): @@ -509,7 +510,7 @@ def cast(param, value, key=None): if k in id_map: param = id_map[k] new_states[param] = cast(param, v) - else: + elif not strict: new_states[k] = v optimizer.state.update(new_states) From aa5fecddadc243a90727d146dc927f7f3ad48781 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 29 Aug 2023 17:14:57 +0800 Subject: [PATCH 08/15] fix bug when loading --- .../hybrid_parallel_checkpoint_io.py | 15 +++-- colossalai/checkpoint_io/utils.py | 56 +++++++++---------- ...st_hybrid_parallel_plugin_checkpoint_io.py | 1 + 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index b006c088c239..00cf83d30644 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.interface import OptimizerWrapper +from colossalai.tensor.d_tensor import is_customized_distributed_tensor, is_distributed_tensor from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -473,6 +474,7 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) ckpt_root_path = ckpt_index_file.root_path weight_map = ckpt_index_file.weight_map + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int # Load param_groups param_group_path = ckpt_index_file.get_param_group_filename() @@ -494,11 +496,7 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): loaded_file = set() for pg in optimizer.optim.param_groups: for param in pg['params']: - param_id = str(_get_param_id_from_optimizer_param(param, self.use_zero)) - if param_id not in weight_map: - raise ValueError( - f"Parameter with ID:{param_id} is not stored in checkpoint, please check your checkpointing configuration!" - ) + param_id = _get_param_id_from_optimizer_param(param, self.use_zero) filename = weight_map[param_id] # If this param's states has been loaded before, directly return. @@ -507,14 +505,21 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): file_path = os.path.join(ckpt_root_path, filename) state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + state_dict = {int(k): v for k, v in state_dict.items()} load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) del state_dict gc.collect() loaded_file.add(filename) + # for param, state in optimizer.optim.state.items(): + # param_id = _get_param_id_from_optimizer_param(param, self.use_zero) + # print(dist.get_rank(), param_id, param.shape, + # param.dtype, param.device, state['exp_avg'].shape, state['exp_avg'].dtype, state['exp_avg'].device) + # Shard the states along tensor parallel group. # Shard the states along data parallel group when Zero is used. 
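        # A rough sketch of what that re-sharding would do to a fully gathered state tensor `v`,
        # assuming `partition_dim`, `tp_rank`/`tp_size` and `dp_rank`/`dp_size` are known:
        #     if partition_dim is not None:                      # tensor parallel slice
        #         v = v.chunk(tp_size, dim=partition_dim)[tp_rank]
        #     if use_zero:                                       # local ZeRO slice
        #         v = torch.nn.functional.pad(v.flatten(), [0, padding]).chunk(dp_size)[dp_rank]
        # where `padding` rounds the flattened length up to a multiple of dp_size.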
+ sharded_optimizer_loading_epilogue(optimizer.optim) def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index d3b36d1803b5..4eb21a17c322 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -96,6 +96,33 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False +def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]: + """ + Given the current shape of parameter and the shape of parameter before sharding, + return the dimension along which the parameter is sharded when using tensor parallel. + If tensor parallel is not used, return None. + + Args: + current_shape (torch.Size): The current shape of parameter after sharding. + original_shape (torch.Size): The shape of parameter before sharding. + tp_size (int): The size of tp group. + + Returns: + Optional[int]: The dimension along which parameter is partitioned. + """ + partition_dim = None + for dim, length in enumerate(original_shape): + if length > current_shape[dim]: + partition_dim = dim + break + if partition_dim is not None: + assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \ + f"The parameter isn't evenly distributed among tensor parallel group: \ + shape before sharding {original_shape}, shape after sharding {current_shape}" + + return partition_dim + + # ====================================== # Helper classes and functions for saving shard file # ====================================== @@ -179,33 +206,6 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to return param_ -def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]: - """ - Given the current shape of parameter and the shape of parameter before sharding, - return the dimension along which the parameter is sharded when using tensor parallel. - If tensor parallel is not used, return None. - - Args: - current_shape (torch.Size): The current shape of parameter after sharding. - original_shape (torch.Size): The shape of parameter before sharding. - tp_size (int): The size of tp group. - - Returns: - Optional[int]: The dimension along which parameter is partitioned. 
- """ - partition_dim = None - for dim, length in enumerate(original_shape): - if length > current_shape[dim]: - partition_dim = dim - break - if partition_dim is not None: - assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \ - f"The parameter isn't evenly distributed among tensor parallel group: \ - shape before sharding {original_shape}, shape after sharding {current_shape}" - - return partition_dim - - def gather_distributed_optimizer_states(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, @@ -236,7 +236,7 @@ def gather_distributed_optimizer_states(state: OrderedDict, v = v.cuda() gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] dist.all_gather(gather_tensor, v, group=tp_group) - param_state = torch.stack(gather_tensor, dim=partition_dim) + param_state = torch.cat(gather_tensor, dim=partition_dim) state_[k] = param_state.detach().cpu() if inplace: del v diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 051a099f6af1..9dcf71386e48 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -64,6 +64,7 @@ def _criterion(outputs, inputs): optimizer = Adam(model.parameters(), lr=1e-3) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + # TODO: If testing stepping accuracy, remember to copy new_model from model, not call model_fn() new_model = model_fn().cuda() new_optimizer = Adam(new_model.parameters(), lr=1e-3) new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) From 4b2ee33f4bed71c464bacabbe432bcc9a7c1b2ba Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 29 Aug 2023 21:10:05 +0800 Subject: [PATCH 09/15] implement optimizer sharded saving --- .../hybrid_parallel_checkpoint_io.py | 151 ++++++++++++++---- colossalai/checkpoint_io/utils.py | 40 +---- ...st_hybrid_parallel_plugin_checkpoint_io.py | 7 +- 3 files changed, 128 insertions(+), 70 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 00cf83d30644..a7811328d7be 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -14,13 +14,11 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from colossalai.interface import OptimizerWrapper -from colossalai.tensor.d_tensor import is_customized_distributed_tensor, is_distributed_tensor from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile from .utils import ( StateDictSharder, - gather_distributed_optimizer_states, gather_distributed_param, get_model_base_filenames, get_optimizer_base_filenames, @@ -30,6 +28,7 @@ load_states_into_optimizer, save_param_groups, save_state_dict_shards, + search_tp_partition_dim, sharded_optimizer_loading_epilogue, ) @@ -120,30 +119,23 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, if param is None: continue - state_ = copy.deepcopy(state) - - # First handle Zero shards. Working params in fp16 should first be obtained. + working_param = param if use_zero: # Zero optimizer has replaced params in param_groups with sharded master params in fp32, - # so we have to first target working params in fp16 using stored mapping. 
+ # so we have to first target working params in fp16 using mapping stored in Zero Optimizer. working_param = optimizer._param_store.master_to_working_param[id(param)] - for k, v in state.items(): - if isinstance(v, torch.Tensor) and k != 'step': - v = v.cuda() - gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] - dist.all_gather(gather_tensor, v, group=dp_group) - param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( - working_param) - state_[k] = param_state.detach().cpu() - param = working_param - - param_id = param_info['param2id'][id(param)] - - # Then hanlde TP Shards - original_shape = param_info['param2shape'][id(param)] - state_ = gather_distributed_optimizer_states(state_, param, original_shape, tp_group) - state_ = {k: v.detach().cpu() for k, v in state_.items()} - block, block_size = state_dict_sharder.append_optim_state(str(param_id), state_) + + param_id = param_info['param2id'][id(working_param)] + original_shape = param_info['param2shape'][id(working_param)] + state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False) + + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) if block is not None: yield block, block_size @@ -505,20 +497,30 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): file_path = os.path.join(ckpt_root_path, filename) state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) - state_dict = {int(k): v for k, v in state_dict.items()} load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) del state_dict gc.collect() loaded_file.add(filename) - # for param, state in optimizer.optim.state.items(): - # param_id = _get_param_id_from_optimizer_param(param, self.use_zero) - # print(dist.get_rank(), param_id, param.shape, - # param.dtype, param.device, state['exp_avg'].shape, state['exp_avg'].dtype, state['exp_avg'].device) + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + device = param.device + working_param = param + if self.use_zero: + working_param = optimizer._param_store.master_to_working_param[id(param)] - # Shard the states along tensor parallel group. + original_shape = optimizer.param_info['param2shape'][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state(state, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True) + optimizer.optim.state[param] = sharded_state - # Shard the states along data parallel group when Zero is used. 
+ for param, state in optimizer.optim.state.items(): + param_id = _get_param_id_from_optimizer_param(param, self.use_zero) + print(dist.get_rank(), param_id, param.shape, param.dtype, param.device, state['exp_avg'].shape, + state['exp_avg'].dtype, state['exp_avg'].device) sharded_optimizer_loading_epilogue(optimizer.optim) @@ -544,3 +546,92 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + + @staticmethod + def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, + dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, + inplace: bool) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The complete optimizer state of given parameter. + """ + dp_size = dist.get_world_size(dp_group) + tp_size = dist.get_world_size(tp_group) + current_shape = param.shape + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != 'step': + + # First gather Zero shards. + if use_zero: + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param) + + # Then gather TP shards. + partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + v = torch.cat(gather_tensor, dim=partition_dim) + + state_[k] = v.detach().clone().cpu() + del v + + return state_ + + def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size, + original_shape: torch.Size, device: torch.device, + inplace: bool) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. + + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != 'step': + + # Shard state along tensor parallel group. 
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + if partition_dim is not None: + slice_size = current_shape[partition_dim] + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + # Shard state along data parallel group when using Zero. + if self.use_zero: + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + state_[k] = v.detach().clone().to(device) + del v + + return state_ diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 4eb21a17c322..1f8badb99bfe 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -206,43 +206,6 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to return param_ -def gather_distributed_optimizer_states(state: OrderedDict, - param: torch.Tensor, - original_shape: torch.Size, - tp_group: ProcessGroup, - inplace: bool = True) -> OrderedDict: - """ - With given parameter and its optimizer states, - gather the complete optimizer state for saving - if the passed in param is distributed under tp setting. - - Args: - state (OrderedDict): Optimizer states of given parameter, might be distributed among tp group if tp is used. - param (torch.Tensor): The given parameter, might be d_tensor. - original_shape (torch.Size): The size of parameter before sharding. - tp_group (ProcessGroup): The process group of tensor parallel. - inplace (bool, optional): If set to True, will update the values of passed in state dict to the gathered states. Defaults to True. - - Returns: - OrderedDict: The complete optimizer state of given parameter. - """ - state_ = state if inplace else copy.deepcopy(state) - if is_distributed_tensor(param) or is_customized_distributed_tensor(param): - tp_size = dist.get_world_size(tp_group) - partition_dim = search_tp_partition_dim(param.shape, original_shape, tp_size) - if partition_dim is not None: - for k, v in state.items(): - if isinstance(v, torch.Tensor) and k != 'step': - v = v.cuda() - gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] - dist.all_gather(gather_tensor, v, group=tp_group) - param_state = torch.cat(gather_tensor, dim=partition_dim) - state_[k] = param_state.detach().cpu() - if inplace: - del v - return state_ - - def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, index_file: "CheckpointIndexFile", @@ -484,6 +447,9 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d strict(bool, optional): If set to True, only load the parameters with its id in id_map. Defaults to False. """ + # Ensure that the keys of state_dict are integers. 
+ state_dict = {int(k): v for k, v in state_dict.items()} + def cast(param, value, key=None): r"""Make a deep copy of value, casting all tensors to device of param.""" if isinstance(value, torch.Tensor): diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 9dcf71386e48..760ffc2544ba 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -64,7 +64,6 @@ def _criterion(outputs, inputs): optimizer = Adam(model.parameters(), lr=1e-3) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - # TODO: If testing stepping accuracy, remember to copy new_model from model, not call model_fn() new_model = model_fn().cuda() new_optimizer = Adam(new_model.parameters(), lr=1e-3) new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) @@ -93,11 +92,13 @@ def _criterion(outputs, inputs): optimizer.step() with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" - # optimizer_ckpt_path = f"{tempdir}/optimizer" + optimizer_ckpt_path = f"{tempdir}/optimizer" booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) - # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() + booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) From 17f19bebb13a9b06d22086d1c1d50402ea2d2abe Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 29 Aug 2023 22:40:44 +0800 Subject: [PATCH 10/15] add optimizer test & arrange checkpointIO utils --- .../hybrid_parallel_checkpoint_io.py | 24 +- colossalai/checkpoint_io/utils.py | 365 ++++++++---------- ...st_hybrid_parallel_plugin_checkpoint_io.py | 43 ++- 3 files changed, 191 insertions(+), 241 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index a7811328d7be..11509362a144 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -111,7 +111,6 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, # An internel method that breaks state_dict of optimizer into shards within limited size. state_dict_sharder = StateDictSharder(size_per_shard) - dp_size = dist.get_world_size(dp_group) param_info = optimizer.param_info for param, state in optimizer.optim.state.items(): @@ -303,8 +302,11 @@ def _load(name: str): _load(name) # Load buffers. + non_persistent_buffers = set() + for n, m in model.named_modules(): + non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set) for name, buf in model.named_buffers(): - if buf is not None and name not in model._non_persistent_buffers_set: + if buf is not None and name not in non_persistent_buffers: _load(name) # Load extra states. 
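The buffer-loading fix in the hunk above exists because `model._non_persistent_buffers_set` only records the root module's own non-persistent buffers; nested submodules keep their own local sets, so the fully-qualified names have to be rebuilt by walking `named_modules()`. A minimal standalone illustration of that behaviour (module and buffer names are made up for the example):

import torch
import torch.nn as nn


class Block(nn.Module):

    def __init__(self):
        super().__init__()
        # persistent=False buffers never appear in state_dict, so loading must skip them
        self.register_buffer('cache', torch.zeros(4), persistent=False)


model = nn.Sequential(Block(), Block())

non_persistent = set()
for prefix, module in model.named_modules():
    # each submodule only records its *local* buffer names ('cache'),
    # so prepend the module prefix to get the fully-qualified name ('0.cache')
    non_persistent |= {'.'.join((prefix, name)) if prefix else name
                       for name in module._non_persistent_buffers_set}

for name, _ in model.named_buffers():
    print(name, 'skip' if name in non_persistent else 'load')
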
@@ -416,12 +418,7 @@ def save_sharded_optimizer(self, stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] for param_id, state_filename in stage_index_file.weight_map.items(): - if param_id not in final_index_file.weight_map: - final_index_file.append_weight_map(param_id, state_filename) - else: - # If parameter is shared with other stage, delete this copy of optimizer state. - # For example: only save one copy of optimizer states between embedding weight & lm_head. - os.remove(os.path.join(checkpoint, state_filename)) + final_index_file.append_weight_map(param_id, state_filename) # Store param groups. final_index_file.append_meta_data("param_groups", param_group_file) @@ -488,7 +485,11 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): loaded_file = set() for pg in optimizer.optim.param_groups: for param in pg['params']: + if param is None: + continue param_id = _get_param_id_from_optimizer_param(param, self.use_zero) + if param_id not in weight_map: + continue filename = weight_map[param_id] # If this param's states has been loaded before, directly return. @@ -517,11 +518,6 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): inplace=True) optimizer.optim.state[param] = sharded_state - for param, state in optimizer.optim.state.items(): - param_id = _get_param_id_from_optimizer_param(param, self.use_zero) - print(dist.get_rank(), param_id, param.shape, param.dtype, param.device, state['exp_avg'].shape, - state['exp_avg'].dtype, state['exp_avg'].device) - sharded_optimizer_loading_epilogue(optimizer.optim) def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): @@ -629,7 +625,7 @@ def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) slice_size = v.numel() // self.dp_size - v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + v = v.split(slice_size, dim=0)[self.dp_rank] state_[k] = v.detach().clone().to(device) del v diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 1f8badb99bfe..0025d07dfc8e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -166,13 +166,30 @@ def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[Ordere def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]: + # A state might contain more than one tensors. + # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' state_size = 0 - for k, v in state.items(): - if isinstance(v, torch.Tensor): - state_size += calculate_tensor_size(v) + isDTensor = False + for state_tensor in state.values(): + + # When state_tensor is not of Tensor class, + # e.g., a SGD optimizer with momentum set to 0 can have None as state + # The calculation of tensor size should be skipped to avoid error. + if not isinstance(state_tensor, torch.Tensor): + continue + + # If the states are stored as DTensors, mark isDTensor as true. 
+ if is_distributed_tensor(state_tensor): + isDTensor = True + state_size += calculate_tensor_size(state_tensor) + ret_block = None ret_block_size = 0 + # directly return if state is stored as distributed tensor + if isDTensor: + return ret_block, ret_block_size + # before we return the current block and create a new block, # we need to ensure that the current block is not empty if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0: @@ -250,28 +267,17 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. """ - current_block = {} - current_block_size = 0 + state_dict_sharder = StateDictSharder(max_shard_size) for key, weight in state_dict.items(): - ret_block = None - ret_block_size = 0 if not is_distributed_tensor(weight): - weight_size = calculate_tensor_size(weight) + block, block_size = state_dict_sharder.append_param(key, weight) - # If this weight is going to tip up over the maximal size, we split. - if current_block_size + weight_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - current_block[key] = weight - current_block_size += weight_size + if block != None: + yield block, block_size - if ret_block != None: - yield ret_block, ret_block_size - - yield current_block, current_block_size + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: @@ -282,47 +288,147 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> # Only split state_dict['state']; state_dict['param_group'] is not considered in this function. states = state_dict['state'] - - current_block = {} - current_block_size = 0 + state_dict_sharder = StateDictSharder(max_shard_size) for param_id, state in states.items(): + block, block_size = state_dict_sharder.append_optim_state(param_id, state) + if block != None: + yield block, block_size - ret_block = None - ret_block_size = 0 + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - # A state might contain more than one tensors. - # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' - state_size = 0 - isDTensor = False - for state_tensor in state.values(): - # When state_tensor is not of Tensor class, - # e.g., a SGD optimizer with momentum set to 0 can have None as state - # The calculation of tensor size should be skipped to avoid error. - if not isinstance(state_tensor, torch.Tensor): - continue +# ====================================== +# Helper functions for saving state dict +# ====================================== - # If the states are stored as DTensors, mark isDTensor as true. - if is_distributed_tensor(state_tensor): - isDTensor = True - state_size += calculate_tensor_size(state_tensor) - if not isDTensor: +def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: + """ + Save state dict to checkpoint. + + Args: + state_dict (dict): state dict. + checkpoint_file_path (str): path to the checkpoint file. + use_safetensors (bool): whether to use safetensors to save the checkpoint. 
+ """ + if use_safetensors: + assert is_safetensors_available(), "safetensors is not available." + assert checkpoint_file_path.endswith('.safetensors'), \ + "safetensors only supports .safetensors suffix for checkpoint file." + from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) + else: + torch.save(state_dict, checkpoint_file_path) + + +def save_param_groups(state_dict: dict, group_file_path: str) -> None: + """ + Save information of param_groups to given file path. + + Args: + state_dict (dict): state dict. + group_file_path (str): path to the group file. + """ + param_groups = state_dict["param_groups"] + torch.save(param_groups, group_file_path) + + +def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: + """ + Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains + only one tensor. + + Args: + tensor (Tensor): tensor to be saved. + index_file (CheckpointIndexFile): path to the checkpoint file. + size_per_shard (int): size per shard in MB. + """ + root_path = index_file.root_path + output_root_path = root_path.joinpath('dtensor') + + # create directory + output_root_path.mkdir(exist_ok=True) + + # save tensor to this directory + # TODO(YuliangLiu): get index of the tensor shard + # e.g. index = + index = 0 + + # save tensor to file + ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors) + ckpt_file_path = output_root_path.joinpath(ckpt_file_name) + + # dtensor ckpt file always contains only one tensor + state_dict = {name: tensor} + save_state_dict(state_dict, str(ckpt_file_path), use_safetensors) + + # update the weight map + # * means all shards + ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) + index_file.append_weight_map(name, ckpt_file_name_in_weight_map) + + +def get_checkpoint_file_suffix(use_safetensors: bool) -> str: + """ + Get checkpoint file suffix. + + Args: + use_safetensors (bool): whether to use safetensors to save the checkpoint. + + Returns: + str: checkpoint file suffix. + """ + if use_safetensors: + return '.safetensors' + else: + return '.bin' + + +def generate_checkpoint_shard_file_name(index: int, + total_number: int, + use_safetensors: bool, + prefix: str = None) -> str: + """ + Generate checkpoint shard file name. + + Args: + index (int): index of the shard. + total_number (int): total number of shards. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + prefix (str): prefix of the shard file name. Default: None. + + Returns: + str: checkpoint shard file name. + """ + suffix = get_checkpoint_file_suffix(use_safetensors) - if current_block_size + state_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 + if prefix is None: + return f"{index:05d}-of-{total_number:05d}.{suffix}" + else: + return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}" + + +def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str: + """ + Generate dtensor file name. + + Args: + param_name (str): name of the distributed parameter. + index (int): index of the shard. + use_safetensors (bool): whether to use safetensors to save the checkpoint. - current_block[param_id] = state - current_block_size += state_size + Returns: + str: dtensor file name. 
+ """ + suffix = get_checkpoint_file_suffix(use_safetensors) + return f'{param_name}.{index}.{suffix}' - if ret_block != None: - yield ret_block, ret_block_size - yield current_block, current_block_size +# ======================================== +# Helper functions for loading state dict +# ======================================== def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): @@ -494,165 +600,6 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer): optimizer.defaults.setdefault('differentiable', False) -# ====================================== -# Helper functions for saving state dict -# ====================================== - - -def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: - """ - Save state dict to checkpoint. - - Args: - state_dict (dict): state dict. - checkpoint_file_path (str): path to the checkpoint file. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - """ - if use_safetensors: - assert is_safetensors_available(), "safetensors is not available." - assert checkpoint_file_path.endswith('.safetensors'), \ - "safetensors only supports .safetensors suffix for checkpoint file." - from safetensors.torch import save_file as safe_save_file - safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) - else: - torch.save(state_dict, checkpoint_file_path) - - -def save_param_groups(state_dict: dict, group_file_path: str) -> None: - """ - Save information of param_groups to given file path. - - Args: - state_dict (dict): state dict. - group_file_path (str): path to the group file. - """ - param_groups = state_dict["param_groups"] - torch.save(param_groups, group_file_path) - - -def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: - """ - Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains - only one tensor. - - Args: - tensor (Tensor): tensor to be saved. - index_file (CheckpointIndexFile): path to the checkpoint file. - size_per_shard (int): size per shard in MB. - """ - root_path = index_file.root_path - output_root_path = root_path.joinpath('dtensor') - - # create directory - output_root_path.mkdir(exist_ok=True) - - # save tensor to this directory - # TODO(YuliangLiu): get index of the tensor shard - # e.g. index = - index = 0 - - # save tensor to file - ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors) - ckpt_file_path = output_root_path.joinpath(ckpt_file_name) - - # dtensor ckpt file always contains only one tensor - state_dict = {name: tensor} - save_state_dict(state_dict, str(ckpt_file_path), use_safetensors) - - # update the weight map - # * means all shards - ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) - index_file.append_weight_map(name, ckpt_file_name_in_weight_map) - - -def get_checkpoint_file_suffix(use_safetensors: bool) -> str: - """ - Get checkpoint file suffix. - - Args: - use_safetensors (bool): whether to use safetensors to save the checkpoint. - - Returns: - str: checkpoint file suffix. - """ - if use_safetensors: - return '.safetensors' - else: - return '.bin' - - -def generate_checkpoint_shard_file_name(index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None) -> str: - """ - Generate checkpoint shard file name. - - Args: - index (int): index of the shard. - total_number (int): total number of shards. 
- use_safetensors (bool): whether to use safetensors to save the checkpoint. - prefix (str): prefix of the shard file name. Default: None. - - Returns: - str: checkpoint shard file name. - """ - suffix = get_checkpoint_file_suffix(use_safetensors) - - if prefix is None: - return f"{index:05d}-of-{total_number:05d}.{suffix}" - else: - return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}" - - -def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str: - """ - Generate dtensor file name. - - Args: - param_name (str): name of the distributed parameter. - index (int): index of the shard. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - - Returns: - str: dtensor file name. - """ - suffix = get_checkpoint_file_suffix(use_safetensors) - return f'{param_name}.{index}.{suffix}' - - -def save_state_dict_as_shard( - state_dict: dict, - checkpoint_path: str, - index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None, -) -> None: - """ - Save state dict as shard. - - Args: - state_dict (dict): state dict. - checkpoint_path (str): path to the checkpoint file. - index (int): index of the shard. - total_number (int): total number of shards. - prefix (str): prefix of the shard file name. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - """ - # generate the shard name - shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix) - shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute() - - # save the shard - save_state_dict(state_dict, str(shard_file_path), use_safetensors) - - -# ======================================== -# Helper functions for loading state dict -# ======================================== - - def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: """ Check whether the checkpoint has an index file. 
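The utils.py reorganisation above funnels both model and optimizer sharding through the StateDictSharder accumulate-and-flush pattern: `append` only hands back a completed block once the next entry would overflow the size budget, and the caller is expected to yield the final, partially-filled block itself. A simplified stand-in for that pattern (the class name and byte-size heuristic below are illustrative, not the real API):

from collections import OrderedDict
from typing import Optional, Tuple

import torch


class TinySharder:

    def __init__(self, max_shard_size_mb: int):
        self.max_size = max_shard_size_mb * 1024 * 1024
        self.current_block = OrderedDict()
        self.current_block_size = 0

    def append(self, key: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
        size = tensor.numel() * tensor.element_size()
        ret_block, ret_size = None, 0
        # flush the previous block only if it is non-empty and would overflow
        if self.current_block_size + size > self.max_size and self.current_block_size > 0:
            ret_block, ret_size = self.current_block, self.current_block_size
            self.current_block, self.current_block_size = OrderedDict(), 0
        self.current_block[key] = tensor
        self.current_block_size += size
        return ret_block, ret_size


sharder = TinySharder(max_shard_size_mb=1)
for i in range(5):
    block, size = sharder.append(f'w{i}', torch.zeros(128, 1024))    # ~0.5 MB each
    if block is not None:
        print('shard ready:', list(block), size)
# the caller is responsible for the final, partially-filled block
print('last shard:', list(sharder.current_block), sharder.current_block_size)
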
diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 760ffc2544ba..c96c97aa381b 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -25,14 +25,14 @@ @parameterize('model_name', ['transformers_gpt', 'transformers_bert']) @parameterize('size_per_shard', [32]) @parameterize('test_config', [{ + 'tp_size': 4, + 'pp_size': 1, + 'precision': 'fp32', +}, { 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4, 'precision': 'fp32', -}, { - 'tp_size': 4, - 'pp_size': 1, - 'precision': 'fp32', }, { 'tp_size': 2, 'pp_size': 1, @@ -60,6 +60,17 @@ def _criterion(outputs, inputs): loss = criterion(outputs) return loss + def _preprocess_data(data): + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + return iter([data]) + else: + return {k: v.cuda() for k, v in data.items()} + model = model_fn().cuda() optimizer = Adam(model.parameters(), lr=1e-3) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) @@ -71,21 +82,14 @@ def _criterion(outputs, inputs): data = data_gen_fn() model.train() if booster.plugin.stage_manager is not None: - for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 4 - data[k] = v.to('cuda').repeat(*new_shape) - data_iter = iter([data]) - output = booster.execute_pipeline(data_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=False) + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) else: - data = {k: v.cuda() for k, v in data.items()} - output = model(**data) + output = model(**_preprocess_data(data)) loss = criterion(output) optimizer.backward(loss) @@ -101,6 +105,9 @@ def _criterion(outputs, inputs): booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False) + dist.barrier() Randomizer.reset_index() clear_layout_converter() From 1cec8cdfca6c22b5f21b0a6fc0c4ace945118538 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 30 Aug 2023 00:12:42 +0800 Subject: [PATCH 11/15] fix gemini sharding state_dict --- colossalai/zero/gemini/gemini_ddp.py | 6 +-- colossalai/zero/gemini/gemini_optimizer.py | 44 +++------------------- 2 files changed, 9 insertions(+), 41 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 5aff91f03153..1c19071feb67 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -679,7 +679,7 @@ def state_dict_shard(self, gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) gathered_param = gathered_param_buffer.pop(fp32_param) - block, block_size = sharder.append(prefix + name, gathered_param) + block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: yield block, block_size @@ -690,7 +690,7 @@ def state_dict_shard(self, for name, buf in self.named_buffers(): if buf is not None 
and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block, block_size = sharder.append(prefix + name, buffer) + block, block_size = sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size # save extra states @@ -698,7 +698,7 @@ def state_dict_shard(self, if getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = self.get_extra_state() - block, block_size = sharder.append(extra_state_key, extra_state) + block, block_size = sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index a2085323f83e..58b0f33ab189 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -10,7 +10,7 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin -from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam from colossalai.tensor.d_tensor import is_distributed_tensor @@ -691,49 +691,17 @@ def state_shard(self, Iterator[OrderedDict]: A generator of state dict shard of optimizer states. """ - current_block = {} - current_block_size = 0 - + sharder = StateDictSharder(max_shard_size) for param_id in self.id_to_real_params.keys(): dist.barrier() state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) - ret_block = None - ret_block_size = 0 - - # A state might contain more than one tensors. - # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' - state_size = 0 - isDTensor = False - for state_tensor in state.values(): - - # When state_tensor is not of Tensor class, - # e.g., a SGD optimizer with momentum set to 0 can have None as state - # The calculation of tensor size should be skipped to avoid error. - if not isinstance(state_tensor, torch.Tensor): - continue - - # If the states are stored as DTensors, mark isDTensor as true. 
- if is_distributed_tensor(state_tensor): - isDTensor = True - state_size += calculate_tensor_size(state_tensor) - - if not isDTensor: - - if current_block_size + state_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - - current_block[param_id] = state - current_block_size += state_size - - if ret_block != None: - yield ret_block, ret_block_size + block, block_size = sharder.append_optim_state(param_id, state) + if block is not None: + yield block, block_size - yield current_block, current_block_size + yield sharder.current_block, sharder.current_block_size class GeminiAdamOptimizer(ZeroOptimizer): From ad5c37b693eca34bf50b92a7a4f5af97932a1f48 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 30 Aug 2023 16:33:18 +0800 Subject: [PATCH 12/15] add verbose option --- .../hybrid_parallel_checkpoint_io.py | 50 +++++++++++-------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 11509362a144..6674734b4911 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -47,9 +47,15 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): pp_group (ProcessGroup): Process group along pipeline parallel dimension. tp_group (ProcessGroup): Process group along tensor parallel dimension. zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. + verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True. """ - def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup, zero_stage: int) -> None: + def __init__(self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True) -> None: super().__init__() self.dp_group = dp_group self.pp_group = pp_group @@ -61,6 +67,7 @@ def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: Pro self.pp_size = dist.get_world_size(pp_group) self.tp_size = dist.get_world_size(tp_group) self.use_zero = (zero_stage > 0) + self.verbose = verbose @staticmethod def _model_sharder(model: nn.Module, @@ -195,9 +202,10 @@ def save_sharded_model(self, if control_saving: index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") else: # When pipeline is used, each stage produces its own shard files and index files. @@ -242,9 +250,10 @@ def save_sharded_model(self, final_index_file.write_index_file(final_index_file_path) rmtree(tmp_index_file_folder) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}.") + if self.verbose: + logging.info(f"The model is split into checkpoint shards. 
" + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): """ @@ -293,8 +302,6 @@ def _load(name: str): missing_keys=missing_keys, strict=strict, load_sub_module=True) - del state_dict - gc.collect() loaded_file.add(filename) # Load parameters. @@ -315,6 +322,9 @@ def _load(name: str): torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: _load(extra_state_key) + if self.verbose: + logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + def save_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, @@ -375,9 +385,10 @@ def save_sharded_optimizer(self, # Store index file. index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info(f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + if self.verbose: + logging.info(f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") else: # When pipeline is used, each stage produces its own shard files and index files. @@ -426,11 +437,12 @@ def save_sharded_optimizer(self, save_param_groups(optimizer.param_info, group_file_path) final_index_file.write_index_file(final_index_file_path) - rmtree(tmp_index_file_folder) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}.") + + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): """ @@ -499,8 +511,6 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): file_path = os.path.join(ckpt_root_path, filename) state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) - del state_dict - gc.collect() loaded_file.add(filename) # Then shard the loaded optimizer states if using tp/zero. 
@@ -519,6 +529,8 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): optimizer.optim.state[param] = sharded_state sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose: + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): # TODO(Baizhou): support this feature after implementing complete state_dict collection @@ -585,7 +597,6 @@ def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, v = torch.cat(gather_tensor, dim=partition_dim) state_[k] = v.detach().clone().cpu() - del v return state_ @@ -628,6 +639,5 @@ def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: v = v.split(slice_size, dim=0)[self.dp_rank] state_[k] = v.detach().clone().to(device) - del v return state_ From 78ccd225d61fea3de1284686c813c024c50280d9 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 30 Aug 2023 18:57:59 +0800 Subject: [PATCH 13/15] add loading of master params --- .../booster/plugin/hybrid_parallel_plugin.py | 6 ++- .../hybrid_parallel_checkpoint_io.py | 41 ++++++++++++++++++- ...st_hybrid_parallel_plugin_checkpoint_io.py | 39 +++++++++++++++++- 3 files changed, 83 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index a49ae44c27a2..0fc63ebad3fb 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -406,6 +406,7 @@ def configure( precision=self.precision, max_norm=self.max_norm, **self.amp_config) + self.checkpoint_io.create_working_to_master_map(optimizer.working_to_master_map) else: optimizer = HybridParallelNaiveOptimizer(optimizer, model, @@ -424,6 +425,8 @@ def configure( clip_grad_norm=self.max_norm, **self.zero_config, **self.amp_config) + self.checkpoint_io.create_working_to_master_map(optimizer._param_store.working_to_master_param) + return model, optimizer, criterion, dataloader, lr_scheduler def execute_pipeline(self, @@ -500,7 +503,8 @@ def seed_worker(worker_id): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + return self.checkpoint_io def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 6674734b4911..376b31a09aed 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -4,7 +4,7 @@ import os from pathlib import Path from shutil import rmtree -from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union +from typing import Iterator, Mapping, Optional, OrderedDict, Tuple, Union import torch import torch.distributed as dist @@ -68,6 +68,7 @@ def __init__(self, self.tp_size = dist.get_world_size(tp_group) self.use_zero = (zero_stage > 0) self.verbose = verbose + self.working_to_master_map = None @staticmethod def _model_sharder(model: nn.Module, @@ -322,6 +323,25 @@ def _load(name: str): torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: _load(extra_state_key) + # Update master params if 
mixed-precision training is enabled. + with torch.no_grad(): + if self.working_to_master_map is not None: + for param in model.parameters(): + if (param is None) or (id(param) not in self.working_to_master_map): + continue + master_param = self.working_to_master_map[id(param)] + if self.use_zero: + # master_param is sharded under Zero setting + padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size + if padding_size > 0: + padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padded_param = param.data.view(-1) + sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank] + master_param.data.copy_(sharded_param.data) + else: + master_param.data.copy_(param.data) + if self.verbose: logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") @@ -555,6 +575,25 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + def create_working_to_master_map(self, working_to_master_map: Mapping[Union[int, torch.Tensor], torch.Tensor]): + """ + Create current checkpoint IO's working_to_master_map with passed in mapping. + This mapping can only be created when mixied precision is used. + The created mapping should be a mapping from address of working parameters to master parameter objects. + + Args: + working_to_master_map (Mapping[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects. + """ + self.working_to_master_map = dict() + for k, v in working_to_master_map.items(): + if isinstance(k, torch.Tensor): + self.working_to_master_map[id(k)] = v + elif isinstance(k, int): + self.working_to_master_map[k] = v + else: + raise ValueError( + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + @staticmethod def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index c96c97aa381b..cdf74c2e9495 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -10,6 +10,7 @@ from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( + assert_close_loose, check_state_dict_equal, clear_cache_before_run, parameterize, @@ -22,7 +23,7 @@ # TODO (Baizhou): Add test cases for shard=False @clear_cache_before_run() @parameterize('shard', [True]) -@parameterize('model_name', ['transformers_gpt', 'transformers_bert']) +@parameterize('model_name', ['transformers_gpt']) @parameterize('size_per_shard', [32]) @parameterize('test_config', [{ 'tp_size': 4, @@ -109,6 +110,42 @@ def _preprocess_data(data): check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False) dist.barrier() + # Check whether the loaded model & optimizer works smoothly. 
+ model.train() + new_model.train() + if booster.plugin.stage_manager is not None: + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + booster.execute_pipeline(_preprocess_data(data), + new_model, + _criterion, + new_optimizer, + return_loss=True, + return_outputs=False) + else: + old_model_loss = criterion(model(**_preprocess_data(data))) + optimizer.backward(old_model_loss) + new_model_loss = criterion(new_model(**_preprocess_data(data))) + new_optimizer.backward(new_model_loss) + + optimizer.step() + new_optimizer.step() + + # Check updated weights. + stage_manager = booster.plugin.stage_manager + + if stage_manager is None or stage_manager.is_first_stage(): + assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) + assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data, + new_model.unwrap().h[0].mlp.c_fc.weight.data, + atol=5e-3, + rtol=5e-3) + + dist.barrier() Randomizer.reset_index() clear_layout_converter() From 52491b54d5f8eb151348c8fca6b130b170629c91 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 30 Aug 2023 23:30:53 +0800 Subject: [PATCH 14/15] fix typehint --- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 376b31a09aed..0d8ca18b9c7d 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -4,7 +4,7 @@ import os from pathlib import Path from shutil import rmtree -from typing import Iterator, Mapping, Optional, OrderedDict, Tuple, Union +from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union import torch import torch.distributed as dist @@ -575,14 +575,14 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) - def create_working_to_master_map(self, working_to_master_map: Mapping[Union[int, torch.Tensor], torch.Tensor]): + def create_working_to_master_map(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor]): """ Create current checkpoint IO's working_to_master_map with passed in mapping. This mapping can only be created when mixied precision is used. The created mapping should be a mapping from address of working parameters to master parameter objects. Args: - working_to_master_map (Mapping[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects. + working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects. 
""" self.working_to_master_map = dict() for k, v in working_to_master_map.items(): From ae93561d8d929a42b60ce845d08227715140c27e Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 31 Aug 2023 12:20:32 +0800 Subject: [PATCH 15/15] fix master/working mapping in fp16 amp --- .../booster/plugin/hybrid_parallel_plugin.py | 6 +- .../hybrid_parallel_checkpoint_io.py | 65 ++++++++++++------- ...st_hybrid_parallel_plugin_checkpoint_io.py | 11 ++-- 3 files changed, 50 insertions(+), 32 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 0fc63ebad3fb..277843b66568 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -406,7 +406,8 @@ def configure( precision=self.precision, max_norm=self.max_norm, **self.amp_config) - self.checkpoint_io.create_working_to_master_map(optimizer.working_to_master_map) + self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, + optimizer.master_to_working_map) else: optimizer = HybridParallelNaiveOptimizer(optimizer, model, @@ -425,7 +426,8 @@ def configure( clip_grad_norm=self.max_norm, **self.zero_config, **self.amp_config) - self.checkpoint_io.create_working_to_master_map(optimizer._param_store.working_to_master_param) + self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, + optimizer._param_store.master_to_working_param) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 0d8ca18b9c7d..c128858b1efe 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -69,6 +69,7 @@ def __init__(self, self.use_zero = (zero_stage > 0) self.verbose = verbose self.working_to_master_map = None + self.master_to_working_map = None @staticmethod def _model_sharder(model: nn.Module, @@ -114,6 +115,7 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, use_zero: bool, dp_group: ProcessGroup, tp_group: ProcessGroup, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, size_per_shard: int = 1024): # An internel method that breaks state_dict of optimizer into shards within limited size. @@ -126,11 +128,10 @@ def _optimizer_sharder(optimizer: OptimizerWrapper, if param is None: continue - working_param = param - if use_zero: - # Zero optimizer has replaced params in param_groups with sharded master params in fp32, - # so we have to first target working params in fp16 using mapping stored in Zero Optimizer. - working_param = optimizer._param_store.master_to_working_param[id(param)] + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param param_id = param_info['param2id'][id(working_param)] original_shape = param_info['param2shape'][id(working_param)] @@ -380,11 +381,13 @@ def save_sharded_optimizer(self, # Then collect the sharded states along dp_group(if using zero)/tp_group. # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. 
- state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(optimizer, - use_zero=self.use_zero, - dp_group=self.dp_group, - tp_group=self.tp_group, - size_per_shard=size_per_shard) + state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder( + optimizer, + use_zero=self.use_zero, + dp_group=self.dp_group, + tp_group=self.tp_group, + master_to_working_map=self.master_to_working_map, + size_per_shard=size_per_shard) states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) index_file = CheckpointIndexFile(checkpoint) control_saving = (self.dp_rank == 0 and self.tp_rank == 0) @@ -474,13 +477,13 @@ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_f prefix (str): Not used. """ - def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): - if use_zero: - working_param = optimizer._param_store.master_to_working_param[id(param)] - param_id = optimizer.param_info['param2id'][id(working_param)] + def _get_param_id_from_optimizer_param(param: torch.Tensor, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] else: - param_id = optimizer.param_info['param2id'][id(param)] - return param_id + working_param = param + return optimizer.param_info['param2id'][id(working_param)] # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. # When Zero is used, the mapped parameter objects should be fp32 master parameters. @@ -488,7 +491,7 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): id_map = {} for pg in optimizer.optim.param_groups: for param in pg['params']: - param_id = _get_param_id_from_optimizer_param(param, self.use_zero) + param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) id_map[param_id] = param # Read checkpoint index file. @@ -519,7 +522,7 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): for param in pg['params']: if param is None: continue - param_id = _get_param_id_from_optimizer_param(param, self.use_zero) + param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) if param_id not in weight_map: continue filename = weight_map[param_id] @@ -536,10 +539,10 @@ def _get_param_id_from_optimizer_param(param: torch.Tensor, use_zero: bool): # Then shard the loaded optimizer states if using tp/zero. for param, state in optimizer.optim.state.items(): device = param.device - working_param = param - if self.use_zero: - working_param = optimizer._param_store.master_to_working_param[id(param)] - + if self.master_to_working_map is not None: + working_param = self.master_to_working_map[id(param)] + else: + working_param = param original_shape = optimizer.param_info['param2shape'][id(working_param)] sharded_state = self.shard_from_complete_optimizer_state(state, current_shape=working_param.shape, @@ -575,14 +578,16 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) - def create_working_to_master_map(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor]): + def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor], + master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]): """ - Create current checkpoint IO's working_to_master_map with passed in mapping. 
+ Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings. This mapping can only be created when mixied precision is used. - The created mapping should be a mapping from address of working parameters to master parameter objects. + The created mappings should be mappings from integer parameter addresses to parameter objects. Args: working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects. + master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects. """ self.working_to_master_map = dict() for k, v in working_to_master_map.items(): @@ -594,6 +599,16 @@ def create_working_to_master_map(self, working_to_master_map: Dict[Union[int, to raise ValueError( f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + self.master_to_working_map = dict() + for k, v in master_to_working_map.items(): + if isinstance(k, torch.Tensor): + self.master_to_working_map[id(k)] = v + elif isinstance(k, int): + self.master_to_working_map[k] = v + else: + raise ValueError( + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + @staticmethod def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index cdf74c2e9495..e43908e0c651 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -33,7 +33,8 @@ 'tp_size': 2, 'pp_size': 2, 'num_microbatches': 4, - 'precision': 'fp32', + 'precision': 'fp16', + 'initial_scale': 1 }, { 'tp_size': 2, 'pp_size': 1, @@ -76,10 +77,6 @@ def _preprocess_data(data): optimizer = Adam(model.parameters(), lr=1e-3) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - new_model = model_fn().cuda() - new_optimizer = Adam(new_model.parameters(), lr=1e-3) - new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) - data = data_gen_fn() model.train() if booster.plugin.stage_manager is not None: @@ -104,6 +101,10 @@ def _preprocess_data(data): booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() + new_model = model_fn().cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) booster.load_optimizer(new_optimizer, optimizer_ckpt_path)