diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index c49b3e1823cd..277843b66568 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,7 +1,7 @@ import random from contextlib import nullcontext from functools import partial -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union import numpy as np import torch @@ -110,6 +110,36 @@ def unwrap(self): return module +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A complete param_group, with params in the form of param_id + # 2. A mapping from param address (obtained using id(param)) to integer param_id + # 3. A mapping from integer param_id to param address. + # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. + # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. 
+ + if optim is None: + return {} + param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}} + start_index = 0 + for group in optim.param_groups: + + packed_group = {k: v for k, v in group.items() if k != 'params'} + packed_group['params'] = [] + + for param_id, param in enumerate(group['params'], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + packed_group['params'].append(param_id) + param_info['param2id'][id(param)] = param_id + param_info['id2param'][param_id] = id(param) + param_info['param2shape'][id(param)] = original_shape + + param_info['param_groups'].append(packed_group) + start_index += len(group['params']) + + return param_info + + def init_pipeline_optimizer(optim: Optimizer, model: Module): params = set(model.parameters()) new_param_groups = [] @@ -121,7 +151,8 @@ def init_pipeline_optimizer(optim: Optimizer, model: Module): class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool): + def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim) @@ -133,6 +164,7 @@ def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, + param_info: OrderedDict, precision: str = 'fp16', initial_scale: float = 2**16, min_scale: float = 1, @@ -142,6 +174,7 @@ def __init__(self, hysteresis: int = 2, max_scale: float = 2**32, max_norm: float = 0): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optim, model) super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -155,6 +188,7 @@ def __init__( optimizer: Optimizer, model: Module, use_pipeline: bool, + param_info: OrderedDict, initial_scale: int = 2**16, # grad scaler config min_scale: int = 1, growth_factor: float = 2., @@ -172,6 +206,7 @@ def 
__init__( dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None): + self.param_info = param_info if use_pipeline: init_pipeline_optimizer(optimizer, model) super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, @@ -356,6 +391,7 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, @@ -366,25 +402,33 @@ def configure( optimizer = HybridParallelAMPOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, precision=self.precision, max_norm=self.max_norm, **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, + optimizer.master_to_working_map) else: optimizer = HybridParallelNaiveOptimizer(optimizer, model, - use_pipeline=self.enable_pipeline_parallelism) + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info) else: assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1." assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO." 
optimizer = HybridParallelZeroOptimizer(optimizer, model, use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, dp_process_group=self.dp_group, tp_process_group=self.tp_group, verbose=True, clip_grad_norm=self.max_norm, **self.zero_config, **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, + optimizer._param_store.master_to_working_param) + return model, optimizer, criterion, dataloader, lr_scheduler def execute_pipeline(self, @@ -461,7 +505,8 @@ def seed_worker(worker_id): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group) + self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + return self.checkpoint_io def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 56a89bff75ca..c128858b1efe 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -4,7 +4,7 @@ import os from pathlib import Path from shutil import rmtree -from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union +from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union import torch import torch.distributed as dist @@ -13,29 +13,23 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from colossalai.cluster import ProcessGroupMesh -from colossalai.tensor.d_tensor import ( - is_customized_distributed_tensor, - is_distributed_tensor, - to_global, - to_global_for_customized_distributed_tensor, -) +from colossalai.interface import OptimizerWrapper from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile from .utils 
import ( StateDictSharder, - calculate_tensor_size, gather_distributed_param, get_model_base_filenames, get_optimizer_base_filenames, - get_shard_filename, is_safetensors_available, load_shard_state_dict, load_state_dict_into_model, + load_states_into_optimizer, save_param_groups, - save_state_dict, save_state_dict_shards, + search_tp_partition_dim, + sharded_optimizer_loading_epilogue, ) try: @@ -52,9 +46,16 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): dp_group (ProcessGroup): Process group along data parallel dimension. pp_group (ProcessGroup): Process group along pipeline parallel dimension. tp_group (ProcessGroup): Process group along tensor parallel dimension. + zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. + verbose (bool, optional): Whether to print logging message when saving/loading has been successfully executed. Defaults to True. """ - def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup) -> None: + def __init__(self, + dp_group: ProcessGroup, + pp_group: ProcessGroup, + tp_group: ProcessGroup, + zero_stage: int, + verbose: bool = True) -> None: super().__init__() self.dp_group = dp_group self.pp_group = pp_group @@ -65,6 +66,10 @@ def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: Pro self.dp_size = dist.get_world_size(dp_group) self.pp_size = dist.get_world_size(pp_group) self.tp_size = dist.get_world_size(tp_group) + self.use_zero = (zero_stage > 0) + self.verbose = verbose + self.working_to_master_map = None + self.master_to_working_map = None @staticmethod def _model_sharder(model: nn.Module, @@ -81,7 +86,7 @@ def _model_sharder(model: nn.Module, continue # Gather tensor pieces when using tensor parallel. 
param_ = gather_distributed_param(param, keep_vars=False) - block, block_size = state_dict_sharder.append(prefix + name, param_) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) if block is not None: yield block, block_size @@ -89,7 +94,7 @@ def _model_sharder(model: nn.Module, for name, buf in model.named_buffers(): if buf is not None and name not in model._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block, block_size = state_dict_sharder.append(prefix + name, buffer) + block, block_size = state_dict_sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size @@ -98,7 +103,7 @@ def _model_sharder(model: nn.Module, if getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = model.get_extra_state() - block, block_size = state_dict_sharder.append(extra_state_key, extra_state) + block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size @@ -106,10 +111,44 @@ def _model_sharder(model: nn.Module, yield state_dict_sharder.current_block, state_dict_sharder.current_block_size @staticmethod - def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024): + def _optimizer_sharder(optimizer: OptimizerWrapper, + use_zero: bool, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None, + size_per_shard: int = 1024): + # An internel method that breaks state_dict of optimizer into shards within limited size. - # TODO (Baizhou): Implement sharding feature of optimizer. 
- pass + + state_dict_sharder = StateDictSharder(size_per_shard) + param_info = optimizer.param_info + + for param, state in optimizer.optim.state.items(): + + if param is None: + continue + + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + + param_id = param_info['param2id'][id(working_param)] + original_shape = param_info['param2shape'][id(working_param)] + state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state, + working_param, + original_shape=original_shape, + dp_group=dp_group, + tp_group=tp_group, + use_zero=use_zero, + inplace=False) + + block, block_size = state_dict_sharder.append_optim_state(param_id, state_) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size def save_sharded_model(self, model: nn.Module, @@ -148,7 +187,7 @@ def save_sharded_model(self, return # Then collect the sharded parameters & buffers along tp_group. - # Only devices with tp_size == 0 are responsible for model saving. + # Only devices with tp_rank == 0 are responsible for model saving. state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) @@ -165,9 +204,10 @@ def save_sharded_model(self, if control_saving: index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") + if self.verbose: + logging.info(f"The model is split into checkpoint shards. 
" + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") else: # When pipeline is used, each stage produces its own shard files and index files. @@ -212,9 +252,10 @@ def save_sharded_model(self, final_index_file.write_index_file(final_index_file_path) rmtree(tmp_index_file_folder) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}.") + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): """ @@ -222,7 +263,7 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri Args: model (nn.Module): The model to be loaded. - index_file_path (str): Path to the index file of checkpointing folder. + checkpoint_index_file (str): Path to the index file of checkpointing folder. strict (bool, optional): For name matching during loading state_dict. Defaults to False. This argument should be manually set to False since params on same device might be stored in different files. """ @@ -263,7 +304,6 @@ def _load(name: str): missing_keys=missing_keys, strict=strict, load_sub_module=True) - del state_dict loaded_file.add(filename) # Load parameters. @@ -271,8 +311,11 @@ def _load(name: str): _load(name) # Load buffers. + non_persistent_buffers = set() + for n, m in model.named_modules(): + non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set) for name, buf in model.named_buffers(): - if buf is not None and name not in model._non_persistent_buffers_set: + if buf is not None and name not in non_persistent_buffers: _load(name) # Load extra states. 
@@ -281,16 +324,236 @@ def _load(name: str): torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: _load(extra_state_key) + # Update master params if mixed-precision training is enabled. + with torch.no_grad(): + if self.working_to_master_map is not None: + for param in model.parameters(): + if (param is None) or (id(param) not in self.working_to_master_map): + continue + master_param = self.working_to_master_map[id(param)] + if self.use_zero: + # master_param is sharded under Zero setting + padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size + if padding_size > 0: + padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padded_param = param.data.view(-1) + sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank] + master_param.data.copy_(sharded_param.data) + else: + master_param.data.copy_(param.data) + + if self.verbose: + logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") + def save_sharded_optimizer(self, - optimizer: Optimizer, + optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = True, prefix: Optional[str] = None, size_per_shard: int = 1024): - pass + """ + Save sharded optimizer checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names + - A group file (pytorch_optim_group.bin) recording information of param_groups + - Multiple files that store state tensors of optimizers. + If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin". 
+ If pipeline parallelism is not used, "pytorch_optim.-000XX.bin" + + Args: + optimizer (OptimizerWrapper): Optimizer to save sharded state_dict + checkpoint (str): Path to save optimizer state_dict + gather_dtensor (bool): Whether to gather_dtensor, not used + prefix (str): Prefix of file to save + size_per_shard (int): Max file size of each file shard that stores state tensors + """ + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of states when zero is not used. + # In this case only let the device with dp_rank == 0 save the states. + if not self.use_zero and self.dp_rank != 0: + return + + # Then collect the sharded states along dp_group(if using zero)/tp_group. + # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder( + optimizer, + use_zero=self.use_zero, + dp_group=self.dp_group, + tp_group=self.tp_group, + master_to_working_map=self.master_to_working_map, + size_per_shard=size_per_shard) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + control_saving = (self.dp_rank == 0 and self.tp_rank == 0) + + if self.pp_size == 1: + # When pipeline is not used, save the optimizer shards as in general checkpointIO + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving) + + if control_saving: + # Store param groups. + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + # Store index file. 
+ index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + if self.verbose: + logging.info(f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving) + + if control_saving: + assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. 
+ if self.pp_rank == 0: + + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for param_id, state_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(param_id, state_filename) + + # Store param groups. + final_index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(optimizer.param_info, group_file_path) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + + if self.verbose: + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") + + def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""): + """ + Load sharded optimizer with the given path to index file of checkpoint folder. + + Args: + optimizer (OptimizerWrapper): The optimizer to be loaded. + checkpoint_index_file (str): Path to the index file of checkpointing folder. + prefix (str): Not used. + """ + + def _get_param_id_from_optimizer_param(param: torch.Tensor, + master_to_working_map: Optional[Dict[int, torch.Tensor]] = None): + if master_to_working_map is not None: + working_param = master_to_working_map[id(param)] + else: + working_param = param + return optimizer.param_info['param2id'][id(working_param)] + + # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects. + # When Zero is used, the mapped parameter objects should be fp32 master parameters. + # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info. 
+ id_map = {} + for pg in optimizer.optim.param_groups: + for param in pg['params']: + param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + id_map[param_id] = param - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): - pass + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory.') + saved_groups = torch.load(param_group_path) + + updated_groups = [] + for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): + # obtain updated param group + new_pg = copy.deepcopy(saved_pg) + new_pg['params'] = old_pg['params'] # The parameters in the same group shouldn't change. + updated_groups.append(new_pg) + optimizer.optim.__dict__.update({'param_groups': updated_groups}) + + # Load saved states to optimizer. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + for pg in optimizer.optim.param_groups: + for param in pg['params']: + if param is None: + continue + param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map) + if param_id not in weight_map: + continue + filename = weight_map[param_id] + + # If this param's states have been loaded before, skip it. 
+ if filename in loaded_file: + continue + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) + loaded_file.add(filename) + + # Then shard the loaded optimizer states if using tp/zero. + for param, state in optimizer.optim.state.items(): + device = param.device + if self.master_to_working_map is not None: + working_param = self.master_to_working_map[id(param)] + else: + working_param = param + original_shape = optimizer.param_info['param2shape'][id(working_param)] + sharded_state = self.shard_from_complete_optimizer_state(state, + current_shape=working_param.shape, + original_shape=original_shape, + device=device, + inplace=True) + optimizer.optim.state[param] = sharded_state + + sharded_optimizer_loading_epilogue(optimizer.optim) + if self.verbose: + logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.") def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): # TODO(Baizhou): support this feature after implementing complete state_dict collection @@ -314,3 +577,121 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + + def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor], + master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]): + """ + Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings. + This mapping can only be created when mixed precision is used. + The created mappings should be mappings from integer parameter addresses to parameter objects. 
+ + Args: + working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects. + master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects. + """ + self.working_to_master_map = dict() + for k, v in working_to_master_map.items(): + if isinstance(k, torch.Tensor): + self.working_to_master_map[id(k)] = v + elif isinstance(k, int): + self.working_to_master_map[k] = v + else: + raise ValueError( + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + + self.master_to_working_map = dict() + for k, v in master_to_working_map.items(): + if isinstance(k, torch.Tensor): + self.master_to_working_map[id(k)] = v + elif isinstance(k, int): + self.master_to_working_map[k] = v + else: + raise ValueError( + f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!") + + @staticmethod + def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size, + dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool, + inplace: bool) -> OrderedDict: + """ + With given parameter and its optimizer states, gather the complete optimizer state for saving. + + Args: + state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero. + param (torch.Tensor): The given parameter. It should be working_param when using Zero. + original_shape (torch.Size): The size of parameter before sharding. + dp_group (ProcessGroup): The process group of data parallel. + tp_group (ProcessGroup): The process group of tensor parallel. + use_zero (bool): Whether Zero is used. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The complete optimizer state of given parameter. 
+ """ + dp_size = dist.get_world_size(dp_group) + tp_size = dist.get_world_size(tp_group) + current_shape = param.shape + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != 'step': + + # First gather Zero shards. + if use_zero: + v = v.cuda() + gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] + dist.all_gather(gather_tensor, v, group=dp_group) + v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param) + + # Then gather TP shards. + partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_tensor, v, group=tp_group) + v = torch.cat(gather_tensor, dim=partition_dim) + + state_[k] = v.detach().clone().cpu() + + return state_ + + def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size, + original_shape: torch.Size, device: torch.device, + inplace: bool) -> OrderedDict: + """ + With complete optimizer states of a specific parameter loaded from checkpoint, + slice out the sharded optimizer states kept by current device. + + Args: + state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint. + current_shape (torch.Size): The size of parameter after sharding. + original_shape (torch.Size): The size of parameter before sharding. + device (torch.device): The destination device of loaded optimizer states. + inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state. + + Returns: + OrderedDict: The sharded optimizer state of the given parameter. + """ + state_ = state if inplace else copy.deepcopy(state) + + for k, v in state_.items(): + if isinstance(v, torch.Tensor) and k != 'step': + + # Shard state along tensor parallel group. 
+ partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + if partition_dim is not None: + slice_size = current_shape[partition_dim] + v = v.split(slice_size, dim=partition_dim)[self.tp_rank] + + # Shard state along data parallel group when using Zero. + if self.use_zero: + padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + slice_size = v.numel() // self.dp_size + v = v.split(slice_size, dim=0)[self.dp_rank] + + state_[k] = v.detach().clone().to(device) + + return state_ diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index d04159c54d5e..0025d07dfc8e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,4 +1,5 @@ # coding=utf-8 +import copy import os import re from collections import abc as container_abcs @@ -8,7 +9,9 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple import torch +import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup from torch.optim import Optimizer from colossalai.interface import OptimizerWrapper @@ -93,24 +96,31 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False -def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False): +def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]: """ - Gather the complete parameter for saving if passed in param is distributed. + Given the current shape of parameter and the shape of parameter before sharding, + return the dimension along which the parameter is sharded when using tensor parallel. + If tensor parallel is not used, return None. Args: - param (torch.Tensor): A model parameter, might be d_tensor. - keep_vars (bool, optional): Whether to return the parameter in calculation graph. 
Defaults to False. + current_shape (torch.Size): The current shape of parameter after sharding. + original_shape (torch.Size): The shape of parameter before sharding. + tp_size (int): The size of tp group. Returns: - torch.Tensor: the complete parameter + Optional[int]: The dimension along which parameter is partitioned. """ - param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): - return to_global(param_) - elif is_customized_distributed_tensor(param_): - return to_global_for_customized_distributed_tensor(param_) - else: - return param_ + partition_dim = None + for dim, length in enumerate(original_shape): + if length > current_shape[dim]: + partition_dim = dim + break + if partition_dim is not None: + assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \ + f"The parameter isn't evenly distributed among tensor parallel group: \ + shape before sharding {original_shape}, shape after sharding {current_shape}" + + return partition_dim # ====================================== @@ -136,7 +146,8 @@ def __init__(self, size_per_shard: int) -> None: self.current_block = OrderedDict() self.current_block_size = 0 - def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + tensor_size = calculate_tensor_size(tensor) ret_block = None ret_block_size = 0 @@ -153,6 +164,64 @@ def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict] self.current_block_size += tensor_size return ret_block, ret_block_size + def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]: + + # A state might contain more than one tensors. + # e.g. 
each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' + state_size = 0 + isDTensor = False + for state_tensor in state.values(): + + # When state_tensor is not of Tensor class, + # e.g., a SGD optimizer with momentum set to 0 can have None as state + # The calculation of tensor size should be skipped to avoid error. + if not isinstance(state_tensor, torch.Tensor): + continue + + # If the states are stored as DTensors, mark isDTensor as true. + if is_distributed_tensor(state_tensor): + isDTensor = True + state_size += calculate_tensor_size(state_tensor) + + ret_block = None + ret_block_size = 0 + + # directly return if state is stored as distributed tensor + if isDTensor: + return ret_block, ret_block_size + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[param_id] = state + self.current_block_size += state_size + return ret_block, ret_block_size + + +def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor: + """ + Gather the complete parameter for saving if passed in param is distributed under tp setting. + + Args: + param (torch.Tensor): A model parameter, might be d_tensor. + keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. 
+ + Returns: + torch.Tensor: the complete parameter + """ + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + return to_global(param_) + elif is_customized_distributed_tensor(param_): + return to_global_for_customized_distributed_tensor(param_) + else: + return param_ + def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, @@ -198,28 +267,17 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. """ - current_block = {} - current_block_size = 0 + state_dict_sharder = StateDictSharder(max_shard_size) for key, weight in state_dict.items(): - ret_block = None - ret_block_size = 0 if not is_distributed_tensor(weight): - weight_size = calculate_tensor_size(weight) - - # If this weight is going to tip up over the maximal size, we split. - if current_block_size + weight_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - current_block[key] = weight - current_block_size += weight_size + block, block_size = state_dict_sharder.append_param(key, weight) - if ret_block != None: - yield ret_block, ret_block_size + if block != None: + yield block, block_size - yield current_block, current_block_size + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: @@ -230,47 +288,147 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> # Only split state_dict['state']; state_dict['param_group'] is not considered in this function. 
states = state_dict['state'] - - current_block = {} - current_block_size = 0 + state_dict_sharder = StateDictSharder(max_shard_size) for param_id, state in states.items(): + block, block_size = state_dict_sharder.append_optim_state(param_id, state) + if block != None: + yield block, block_size - ret_block = None - ret_block_size = 0 + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - # A state might contain more than one tensors. - # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' - state_size = 0 - isDTensor = False - for state_tensor in state.values(): - # When state_tensor is not of Tensor class, - # e.g., a SGD optimizer with momentum set to 0 can have None as state - # The calculation of tensor size should be skipped to avoid error. - if not isinstance(state_tensor, torch.Tensor): - continue +# ====================================== +# Helper functions for saving state dict +# ====================================== - # If the states are stored as DTensors, mark isDTensor as true. - if is_distributed_tensor(state_tensor): - isDTensor = True - state_size += calculate_tensor_size(state_tensor) - if not isDTensor: +def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: + """ + Save state dict to checkpoint. + + Args: + state_dict (dict): state dict. + checkpoint_file_path (str): path to the checkpoint file. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + """ + if use_safetensors: + assert is_safetensors_available(), "safetensors is not available." + assert checkpoint_file_path.endswith('.safetensors'), \ + "safetensors only supports .safetensors suffix for checkpoint file." 
+ from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) + else: + torch.save(state_dict, checkpoint_file_path) + + +def save_param_groups(state_dict: dict, group_file_path: str) -> None: + """ + Save information of param_groups to given file path. + + Args: + state_dict (dict): state dict. + group_file_path (str): path to the group file. + """ + param_groups = state_dict["param_groups"] + torch.save(param_groups, group_file_path) + + +def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: + """ + Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains + only one tensor. + + Args: + tensor (Tensor): tensor to be saved under the key `name`. + index_file (CheckpointIndexFile): index file whose root path holds the 'dtensor' directory and whose weight map is updated. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + """ + root_path = index_file.root_path + output_root_path = root_path.joinpath('dtensor') + + # create directory + output_root_path.mkdir(exist_ok=True) + + # save tensor to this directory + # TODO(YuliangLiu): get index of the tensor shard + # e.g. index = + index = 0 + + # save tensor to file + ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors) + ckpt_file_path = output_root_path.joinpath(ckpt_file_name) + + # dtensor ckpt file always contains only one tensor + state_dict = {name: tensor} + save_state_dict(state_dict, str(ckpt_file_path), use_safetensors) + + # update the weight map + # * means all shards + ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) + index_file.append_weight_map(name, ckpt_file_name_in_weight_map) + + +def get_checkpoint_file_suffix(use_safetensors: bool) -> str: + """ + Get checkpoint file suffix. + + Args: + use_safetensors (bool): whether to use safetensors to save the checkpoint. + + Returns: + str: checkpoint file suffix.
+ """ + if use_safetensors: + return '.safetensors' + else: + return '.bin' + + +def generate_checkpoint_shard_file_name(index: int, + total_number: int, + use_safetensors: bool, + prefix: str = None) -> str: + """ + Generate checkpoint shard file name. + + Args: + index (int): index of the shard. + total_number (int): total number of shards. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + prefix (str): prefix of the shard file name. Default: None. + + Returns: + str: checkpoint shard file name. + """ + suffix = get_checkpoint_file_suffix(use_safetensors) + + if prefix is None: + return f"{index:05d}-of-{total_number:05d}.{suffix}" + else: + return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}" + - if current_block_size + state_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 +def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str: + """ + Generate dtensor file name. + + Args: + param_name (str): name of the distributed parameter. + index (int): index of the shard. + use_safetensors (bool): whether to use safetensors to save the checkpoint. - current_block[param_id] = state - current_block_size += state_size + Returns: + str: dtensor file name. 
+ """ + suffix = get_checkpoint_file_suffix(use_safetensors) + return f'{param_name}.{index}.{suffix}' - if ret_block != None: - yield ret_block, ret_block_size - yield current_block, current_block_size +# ======================================== +# Helper functions for loading state dict +# ======================================== def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): @@ -383,17 +541,21 @@ def update_group(group, new_group): return id_map -def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict): +def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False): r"""Copies states from `state_dict` into an Optimizer object. Args: optimizer(Optimizer): An initialized Optimizer object to be loaded - state_dict(dict): a mapping from tensor index (an integer) + state_dict(dict): A mapping from tensor index (an integer) to its states to be loaded (a mapping from state name to a tensor). - id_map(dict): a mapping from tensor index (an integer) + id_map(dict): A mapping from tensor index (an integer) to its corresponding parameter (a tensor) whose states will be updated. + strict(bool, optional): If set to True, only load the states whose ids are in id_map. Defaults to False. """ + # Ensure that the keys of state_dict are integers.
+ state_dict = {int(k): v for k, v in state_dict.items()} + def cast(param, value, key=None): r"""Make a deep copy of value, casting all tensors to device of param.""" if isinstance(value, torch.Tensor): @@ -420,7 +582,7 @@ def cast(param, value, key=None): if k in id_map: param = id_map[k] new_states[param] = cast(param, v) - else: + elif not strict: new_states[k] = v optimizer.state.update(new_states) @@ -438,165 +600,6 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer): optimizer.defaults.setdefault('differentiable', False) -# ====================================== -# Helper functions for saving state dict -# ====================================== - - -def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: - """ - Save state dict to checkpoint. - - Args: - state_dict (dict): state dict. - checkpoint_file_path (str): path to the checkpoint file. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - """ - if use_safetensors: - assert is_safetensors_available(), "safetensors is not available." - assert checkpoint_file_path.endswith('.safetensors'), \ - "safetensors only supports .safetensors suffix for checkpoint file." - from safetensors.torch import save_file as safe_save_file - safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) - else: - torch.save(state_dict, checkpoint_file_path) - - -def save_param_groups(state_dict: dict, group_file_path: str) -> None: - """ - Save information of param_groups to given file path. - - Args: - state_dict (dict): state dict. - group_file_path (str): path to the group file. - """ - param_groups = state_dict["param_groups"] - torch.save(param_groups, group_file_path) - - -def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: - """ - Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains - only one tensor. 
- - Args: - tensor (Tensor): tensor to be saved. - index_file (CheckpointIndexFile): path to the checkpoint file. - size_per_shard (int): size per shard in MB. - """ - root_path = index_file.root_path - output_root_path = root_path.joinpath('dtensor') - - # create directory - output_root_path.mkdir(exist_ok=True) - - # save tensor to this directory - # TODO(YuliangLiu): get index of the tensor shard - # e.g. index = - index = 0 - - # save tensor to file - ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors) - ckpt_file_path = output_root_path.joinpath(ckpt_file_name) - - # dtensor ckpt file always contains only one tensor - state_dict = {name: tensor} - save_state_dict(state_dict, str(ckpt_file_path), use_safetensors) - - # update the weight map - # * means all shards - ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) - index_file.append_weight_map(name, ckpt_file_name_in_weight_map) - - -def get_checkpoint_file_suffix(use_safetensors: bool) -> str: - """ - Get checkpoint file suffix. - - Args: - use_safetensors (bool): whether to use safetensors to save the checkpoint. - - Returns: - str: checkpoint file suffix. - """ - if use_safetensors: - return '.safetensors' - else: - return '.bin' - - -def generate_checkpoint_shard_file_name(index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None) -> str: - """ - Generate checkpoint shard file name. - - Args: - index (int): index of the shard. - total_number (int): total number of shards. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - prefix (str): prefix of the shard file name. Default: None. - - Returns: - str: checkpoint shard file name. 
- """ - suffix = get_checkpoint_file_suffix(use_safetensors) - - if prefix is None: - return f"{index:05d}-of-{total_number:05d}.{suffix}" - else: - return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}" - - -def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str: - """ - Generate dtensor file name. - - Args: - param_name (str): name of the distributed parameter. - index (int): index of the shard. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - - Returns: - str: dtensor file name. - """ - suffix = get_checkpoint_file_suffix(use_safetensors) - return f'{param_name}.{index}.{suffix}' - - -def save_state_dict_as_shard( - state_dict: dict, - checkpoint_path: str, - index: int, - total_number: int, - use_safetensors: bool, - prefix: str = None, -) -> None: - """ - Save state dict as shard. - - Args: - state_dict (dict): state dict. - checkpoint_path (str): path to the checkpoint file. - index (int): index of the shard. - total_number (int): total number of shards. - prefix (str): prefix of the shard file name. - use_safetensors (bool): whether to use safetensors to save the checkpoint. - """ - # generate the shard name - shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix) - shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute() - - # save the shard - save_state_dict(state_dict, str(shard_file_path), use_safetensors) - - -# ======================================== -# Helper functions for loading state dict -# ======================================== - - def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: """ Check whether the checkpoint has an index file. 
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 5aff91f03153..1c19071feb67 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -679,7 +679,7 @@ def state_dict_shard(self, gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) gathered_param = gathered_param_buffer.pop(fp32_param) - block, block_size = sharder.append(prefix + name, gathered_param) + block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: yield block, block_size @@ -690,7 +690,7 @@ def state_dict_shard(self, for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block, block_size = sharder.append(prefix + name, buffer) + block, block_size = sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size # save extra states @@ -698,7 +698,7 @@ def state_dict_shard(self, if getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = self.get_extra_state() - block, block_size = sharder.append(extra_state_key, extra_state) + block, block_size = sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index a2085323f83e..58b0f33ab189 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -10,7 +10,7 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin -from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import 
ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam from colossalai.tensor.d_tensor import is_distributed_tensor @@ -691,49 +691,17 @@ def state_shard(self, Iterator[OrderedDict]: A generator of state dict shard of optimizer states. """ - current_block = {} - current_block_size = 0 - + sharder = StateDictSharder(max_shard_size) for param_id in self.id_to_real_params.keys(): dist.barrier() state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0) - ret_block = None - ret_block_size = 0 - - # A state might contain more than one tensors. - # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq' - state_size = 0 - isDTensor = False - for state_tensor in state.values(): - - # When state_tensor is not of Tensor class, - # e.g., a SGD optimizer with momentum set to 0 can have None as state - # The calculation of tensor size should be skipped to avoid error. - if not isinstance(state_tensor, torch.Tensor): - continue - - # If the states are stored as DTensors, mark isDTensor as true. 
- if is_distributed_tensor(state_tensor): - isDTensor = True - state_size += calculate_tensor_size(state_tensor) - - if not isDTensor: - - if current_block_size + state_size > max_shard_size and current_block_size > 0: - ret_block = current_block - ret_block_size = current_block_size - current_block = {} - current_block_size = 0 - - current_block[param_id] = state - current_block_size += state_size - - if ret_block != None: - yield ret_block, ret_block_size + block, block_size = sharder.append_optim_state(param_id, state) + if block is not None: + yield block, block_size - yield current_block, current_block_size + yield sharder.current_block, sharder.current_block_size class GeminiAdamOptimizer(ZeroOptimizer): diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index 67d73c31f6e0..e43908e0c651 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -10,6 +10,7 @@ from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( + assert_close_loose, check_state_dict_equal, clear_cache_before_run, parameterize, @@ -19,34 +20,34 @@ from tests.kit.model_zoo import model_zoo +# TODO (Baizhou): Add test cases for shard=False @clear_cache_before_run() @parameterize('shard', [True]) @parameterize('model_name', ['transformers_gpt']) @parameterize('size_per_shard', [32]) @parameterize('test_config', [{ - 'tp_size': 2, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp32', -}, { - 'tp_size': 1, - 'pp_size': 2, - 'num_microbatches': 4, - 'precision': 'fp32', -}, { 'tp_size': 4, 'pp_size': 1, 'precision': 'fp32', }, { 'tp_size': 2, - 'pp_size': 1, - 'precision': 'fp32', + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp16', + 'initial_scale': 1 }, { 'tp_size': 2, 
'pp_size': 1, 'zero_stage': 2, 'precision': 'fp16', 'initial_scale': 1 +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'zero_stage': 1, + 'precision': 'fp16', + 'initial_scale': 1 }]) def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): @@ -61,46 +62,91 @@ def _criterion(outputs, inputs): loss = criterion(outputs) return loss + def _preprocess_data(data): + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + return iter([data]) + else: + return {k: v.cuda() for k, v in data.items()} + model = model_fn().cuda() optimizer = Adam(model.parameters(), lr=1e-3) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - new_model = model_fn().cuda() - new_optimizer = Adam(new_model.parameters(), lr=1e-3) - new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) - data = data_gen_fn() model.train() if booster.plugin.stage_manager is not None: - for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: - new_shape = [1] * v.dim() - new_shape[0] = 4 - data[k] = v.to('cuda').repeat(*new_shape) - data_iter = iter([data]) - output = booster.execute_pipeline(data_iter, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=False) + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) else: - data = {k: v.cuda() for k, v in data.items()} - output = model(**data) + output = model(**_preprocess_data(data)) loss = criterion(output) optimizer.backward(loss) optimizer.step() with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" - # optimizer_ckpt_path = f"{tempdir}/optimizer" + optimizer_ckpt_path = f"{tempdir}/optimizer" booster.save_model(model, 
model_ckpt_path, shard=shard, size_per_shard=size_per_shard) - # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() + + new_model = model_fn().cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False) + dist.barrier() + + # Check whether the loaded model & optimizer works smoothly. + model.train() + new_model.train() + if booster.plugin.stage_manager is not None: + booster.execute_pipeline(_preprocess_data(data), + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + booster.execute_pipeline(_preprocess_data(data), + new_model, + _criterion, + new_optimizer, + return_loss=True, + return_outputs=False) + else: + old_model_loss = criterion(model(**_preprocess_data(data))) + optimizer.backward(old_model_loss) + new_model_loss = criterion(new_model(**_preprocess_data(data))) + new_optimizer.backward(new_model_loss) + + optimizer.step() + new_optimizer.step() + + # Check updated weights. + stage_manager = booster.plugin.stage_manager + + if stage_manager is None or stage_manager.is_first_stage(): + assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3) + assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data, + new_model.unwrap().h[0].mlp.c_fc.weight.data, + atol=5e-3, + rtol=5e-3) + dist.barrier() Randomizer.reset_index() clear_layout_converter()