From bd3e4d048e18592add718a9d64282dfdbec2ccb3 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 22 Aug 2023 11:04:22 +0800 Subject: [PATCH 01/11] add APIs --- .../booster/plugin/hybrid_parallel_plugin.py | 73 ++++++++++++++++++- 1 file changed, 70 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 016323ae7821..b28bd9f35ed6 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,11 +1,13 @@ import random from contextlib import nullcontext from functools import partial +from pathlib import Path from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import numpy as np import torch import torch.distributed as dist +import torch.nn as nn from torch.distributed import ProcessGroup from torch.nn import Module, SyncBatchNorm from torch.nn.parallel import DistributedDataParallel as DDP @@ -16,8 +18,8 @@ from torch.utils.data.distributed import DistributedSampler from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer -from colossalai.checkpoint_io import CheckpointIO -from colossalai.cluster import ProcessGroupMesh +from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO +from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager @@ -292,6 +294,7 @@ def __init__(self, self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, @@ -460,7 +463,71 @@ def seed_worker(worker_id): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return None + return HypridParallelCheckpointIO() def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError + + +class HypridParallelCheckpointIO(GeneralCheckpointIO): + + def __init__(self) -> None: + super().__init__() + self.coordinator = DistCoordinator() + + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + # TODO(Baizhou): support this feature after implementing state_dict and load_state_dict + raise NotImplementedError + + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + # TODO(Baizhou): support this feature after implementing state_dict and load_state_dict + raise NotImplementedError + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing state_dict and load_state_dict + raise NotImplementedError + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing state_dict and load_state_dict + raise NotImplementedError + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save lr scheduler to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + def save_sharded_model(self, + model: nn.Module, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False): + """ + TODO (Baizhou): Add docstrings. + """ + pass + + def load_sharded_model(self, + model: nn.Module, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True): + """ + TODO (Baizhou): Add docstrings. + """ + pass + + def save_sharded_optimizer(self, + optimizer: Optimizer, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024): + pass + + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + pass From 66584b5450fd475c2bd98248a6d022b4de394c5b Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 23 Aug 2023 15:56:11 +0800 Subject: [PATCH 02/11] implement save_sharded_model --- .../booster/plugin/hybrid_parallel_plugin.py | 271 +++++++++++++++--- 1 file changed, 238 insertions(+), 33 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index b28bd9f35ed6..2184f232075e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,8 +1,12 @@ +import copy +import logging +import os import random from contextlib import nullcontext from functools import partial from pathlib import Path -from typing import Any, Callable, Iterator, List, Optional, Tuple, Union +from shutil import rmtree +from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union import numpy as np import torch @@ -18,16 +22,37 @@ from torch.utils.data.distributed import DistributedSampler from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer -from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO +from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO +from colossalai.checkpoint_io.utils import ( + calculate_tensor_size, + get_model_base_filenames, + get_optimizer_base_filenames, + get_shard_filename, + load_shard_state_dict, + save_param_groups, + save_state_dict, + save_state_dict_shards, +) from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.tensor.d_tensor import ( + is_customized_distributed_tensor, + is_distributed_tensor, + to_global, + to_global_for_customized_distributed_tensor, +) from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 @@ -463,52 +488,184 @@ def seed_worker(worker_id): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return HypridParallelCheckpointIO() + return HypridParallelCheckpointIO(self.pg_mesh) def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError class HypridParallelCheckpointIO(GeneralCheckpointIO): + """ + CheckpointIO for Hybrid Parallel Training. - def __init__(self) -> None: - super().__init__() - self.coordinator = DistCoordinator() - - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): - # TODO(Baizhou): support this feature after implementing state_dict and load_state_dict - raise NotImplementedError - - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): - # TODO(Baizhou): support this feature after implementing state_dict and load_state_dict - raise NotImplementedError - - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): - # TODO(Baizhou): support this feature after implementing state_dict and load_state_dict - raise NotImplementedError - - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): - # TODO(Baizhou): support this feature after implementing state_dict and load_state_dict - raise NotImplementedError + Args: + pg_mesh (ProcessGroupMesh): Process group mesh containing information of process groups along different dimensions. + """ - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): - """ - Save lr scheduler to checkpoint but only on master process. - """ - if self.coordinator.is_master(): - super().save_lr_scheduler(lr_scheduler, checkpoint) + def __init__(self, pg_mesh: ProcessGroupMesh) -> None: + super().__init__() + self.dp_group = pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = pg_mesh.get_group_along_axis(PP_AXIS) + self.tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + self.dp_rank = dist.get_rank(self.dp_group) + self.tp_rank = dist.get_rank(self.tp_group) + self.pp_rank = dist.get_rank(self.pp_group) + self.dp_size, self.pp_size, self.tp_size = pg_mesh.size + + @staticmethod + def _model_sharder(model: nn.Module, + prefix: str = '', + keep_vars: bool = False, + size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = _StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + # Gather tensor pieces when using tensor parallel. + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + param_ = to_global(param_) + elif is_customized_distributed_tensor(param_): + param_ = to_global_for_customized_distributed_tensor(param_) + + block, block_size = state_dict_sharder.append(prefix + name, param_) + if block is not None: + yield block, block_size + + # Save buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block, block_size = state_dict_sharder.append(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(model.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + extra_state = model.get_extra_state() + block, block_size = state_dict_sharder.append(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + @staticmethod + def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024): + # An internel method that breaks state_dict of optimizer into shards within limited size. + # TODO (Baizhou): Implement sharding feature of optimizer. + pass def save_sharded_model(self, model: nn.Module, - checkpoint_path: str, + checkpoint: str, gather_dtensor: bool = True, prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False): + size_per_shard: int = 1024, + use_safetensors: bool = False) -> None: """ - TODO (Baizhou): Add docstrings. + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint_path (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. """ - pass + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of model. + # So only let the device with dp_rank == 0 save the model. + if self.dp_rank != 0: + return + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_size == 0 are responsible for model saving. + state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = (self.tp_rank == 0) + + if self.pp_size == 1: + # When pipeline is not used, save the model shards as in general checkpointIO + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a tmp folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + save_index_file = os.path.join(tmp_index_file_folder, save_index_file) + + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors) + if control_saving: + assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for weight, weight_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(weight, weight_filename) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") def load_sharded_model(self, model: nn.Module, @@ -531,3 +688,51 @@ def save_sharded_optimizer(self, def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): pass + + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save lr scheduler to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + +class _StateDictSharder: + + def __init__(self, size_per_shard: int) -> None: + self.max_shard_size = size_per_shard + self.current_block = OrderedDict() + self.current_block_size = 0 + + def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + tensor_size = calculate_tensor_size(tensor) + ret_block = None + ret_block_size = 0 + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[name] = tensor + self.current_block_size += tensor_size + return ret_block, ret_block_size From d5617465ea03634c94c8131dc5ba6cd10b53b0e1 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 23 Aug 2023 17:38:02 +0800 Subject: [PATCH 03/11] add test for hybrid checkpointio --- .../booster/plugin/hybrid_parallel_plugin.py | 4 +- ...st_hybrid_parallel_plugin_checkpoint_io.py | 90 +++++++++++++++++++ 2 files changed, 92 insertions(+), 2 deletions(-) create mode 100644 tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 2184f232075e..699a54fa6153 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -510,7 +510,7 @@ def __init__(self, pg_mesh: ProcessGroupMesh) -> None: self.dp_rank = dist.get_rank(self.dp_group) self.tp_rank = dist.get_rank(self.tp_group) self.pp_rank = dist.get_rank(self.pp_group) - self.dp_size, self.pp_size, self.tp_size = pg_mesh.size + self.dp_size, self.pp_size, self.tp_size = pg_mesh.shape @staticmethod def _model_sharder(model: nn.Module, @@ -633,7 +633,7 @@ def save_sharded_model(self, weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors") save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") - save_index_file = os.path.join(tmp_index_file_folder, save_index_file) + save_index_file = os.path.join("tmp_index_files", save_index_file) total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, checkpoint=checkpoint, diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py new file mode 100644 index 000000000000..eec8d3550790 --- /dev/null +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -0,0 +1,90 @@ +import os + +import pytest +import torch +import torch.distributed as dist +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + +# TODO (Baizhou): Add more test configs to go through all kinds of parallel strategy. + + +@clear_cache_before_run() +@parameterize('shard', [True]) +@parameterize('model_name', ['transformers_gpt']) +@parameterize('size_per_shard', [32]) +@parameterize('test_config', [{ + 'tp_size': 2, + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp32', +}]) +def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): + + (model_fn, data_gen_fn, output_transform_fn, loss_fn, + _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + model = model_fn().cuda() + optimizer = Adam(model.parameters(), lr=1e-3) + criterion = loss_fn + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + model.train() + if booster.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to('cuda').repeat(*new_shape) + data_iter = iter([data]) + output = booster.execute_pipeline(data_iter, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=False) + else: + data = {k: v.cuda() for k, v in data.items()} + output = model(**data) + loss = criterion(output) + optimizer.backward(loss) + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) + dist.barrier() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_hybrid_ckpIO(world_size): + spawn(run_dist, world_size) From 3a2cbae6fd7d17892094ac211654a1e7ba2eef61 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 24 Aug 2023 12:20:42 +0800 Subject: [PATCH 04/11] implement naive loading for sharded model --- .../booster/plugin/hybrid_parallel_plugin.py | 45 ++++++++++++++----- ...st_hybrid_parallel_plugin_checkpoint_io.py | 27 ++++++++--- 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 699a54fa6153..42411c0142b3 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -3,7 +3,7 @@ import os import random from contextlib import nullcontext -from functools import partial +from functools import partial, reduce from pathlib import Path from shutil import rmtree from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union @@ -28,7 +28,9 @@ get_model_base_filenames, get_optimizer_base_filenames, get_shard_filename, + is_safetensors_available, load_shard_state_dict, + load_state_dict_into_model, save_param_groups, save_state_dict, save_state_dict_shards, @@ -622,8 +624,8 @@ def save_sharded_model(self, else: # When pipeline is used, each stage produces its own shard files and index files. - # Index files belonging to each stage are saved under a tmp folder ./tmp_index_files/ - # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the folder. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. final_index_file_path = copy.deepcopy(save_index_file) tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") @@ -667,16 +669,37 @@ def save_sharded_model(self, f"You can find where each parameters has been saved in the " f"index located at {final_index_file_path}.") - def load_sharded_model(self, - model: nn.Module, - checkpoint_index_file: Path, - strict: bool = False, - use_safetensors: bool = False, - load_sub_module: bool = True): + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): """ - TODO (Baizhou): Add docstrings. + Load sharded model with the given path to index file of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + index_file_path (str): Path to the index file of checkpointing folder. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. """ - pass + + # Check whether the checkpoint uses safetensors. + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + missing_keys = [] + + # Load params & buffers to model. + # Keep a record of loaded files so that file will not be repeatedly loaded. + strict = False + for shard_file in checkpoint_files: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) + load_state_dict_into_model(model, state_dict, missing_keys, strict=strict, load_sub_module=True) + del state_dict def save_sharded_optimizer(self, optimizer: Optimizer, diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index eec8d3550790..cf2329f82783 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -23,6 +23,7 @@ @clear_cache_before_run() @parameterize('shard', [True]) +@parameterize('use_safetensors', [False, True]) @parameterize('model_name', ['transformers_gpt']) @parameterize('size_per_shard', [32]) @parameterize('test_config', [{ @@ -31,23 +32,27 @@ 'num_microbatches': 4, 'precision': 'fp32', }]) -def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): +def exam_state_dict(shard: bool, use_safetensors: bool, model_name: str, size_per_shard: int, test_config: dict): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) - model = model_fn().cuda() - optimizer = Adam(model.parameters(), lr=1e-3) criterion = loss_fn + plugin = HybridParallelPlugin(**test_config) + booster = Booster(plugin=plugin) def _criterion(outputs, inputs): outputs = output_transform_fn(outputs) loss = criterion(outputs) return loss - plugin = HybridParallelPlugin(**test_config) - booster = Booster(plugin=plugin) + model = model_fn().cuda() + optimizer = Adam(model.parameters(), lr=1e-3) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + new_model = model_fn().cuda() + new_optimizer = Adam(new_model.parameters(), lr=1e-3) + new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion) + data = data_gen_fn() model.train() if booster.plugin.stage_manager is not None: @@ -69,12 +74,20 @@ def _criterion(outputs, inputs): loss = criterion(output) optimizer.backward(loss) + optimizer.step() + with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" - optimizer_ckpt_path = f"{tempdir}/optimizer" - booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) + # optimizer_ckpt_path = f"{tempdir}/optimizer" + booster.save_model(model, + model_ckpt_path, + shard=shard, + size_per_shard=size_per_shard, + use_safetensors=use_safetensors) # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) def run_dist(rank, world_size, port): From 3627a95cec395f01aafaa05272f4f5ee1bae0309 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 24 Aug 2023 14:08:53 +0800 Subject: [PATCH 05/11] implement efficient sharded model loading --- .../booster/plugin/hybrid_parallel_plugin.py | 46 ++++++++++++++++--- ...st_hybrid_parallel_plugin_checkpoint_io.py | 35 +++++++++----- 2 files changed, 63 insertions(+), 18 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 42411c0142b3..1c4684f4aa62 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -690,16 +690,48 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri # Read checkpoint index file. ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) - checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() - missing_keys = [] + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + strict = False # Load params & buffers to model. # Keep a record of loaded files so that file will not be repeatedly loaded. - strict = False - for shard_file in checkpoint_files: - state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) - load_state_dict_into_model(model, state_dict, missing_keys, strict=strict, load_sub_module=True) - del state_dict + loaded_file = set() + + def _load(name: str): + if name not in weight_map: + raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") + filename = weight_map[name] + + # If this param/buffer has been loaded before, directly return. + if filename in loaded_file: + return + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + missing_keys = [] + + load_state_dict_into_model(model, + state_dict, + missing_keys=missing_keys, + strict=strict, + load_sub_module=True) + loaded_file.add(filename) + + # Load parameters. + for name, _ in model.named_parameters(): + _load(name) + + # Load buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + _load(name) + + # Load extra states. + extra_state_key = _EXTRA_STATE_KEY_SUFFIX + if getattr(model.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + _load(extra_state_key) def save_sharded_optimizer(self, optimizer: Optimizer, diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py index cf2329f82783..ea0922ef5dec 100644 --- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py @@ -1,5 +1,3 @@ -import os - import pytest import torch import torch.distributed as dist @@ -9,6 +7,7 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import ( check_state_dict_equal, clear_cache_before_run, @@ -18,12 +17,9 @@ ) from tests.kit.model_zoo import model_zoo -# TODO (Baizhou): Add more test configs to go through all kinds of parallel strategy. - @clear_cache_before_run() @parameterize('shard', [True]) -@parameterize('use_safetensors', [False, True]) @parameterize('model_name', ['transformers_gpt']) @parameterize('size_per_shard', [32]) @parameterize('test_config', [{ @@ -31,8 +27,27 @@ 'pp_size': 2, 'num_microbatches': 4, 'precision': 'fp32', +}, { + 'tp_size': 1, + 'pp_size': 2, + 'num_microbatches': 4, + 'precision': 'fp32', +}, { + 'tp_size': 4, + 'pp_size': 1, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'precision': 'fp32', +}, { + 'tp_size': 2, + 'pp_size': 1, + 'zero_stage': 2, + 'precision': 'fp16', + 'initial_scale': 1 }]) -def exam_state_dict(shard: bool, use_safetensors: bool, model_name: str, size_per_shard: int, test_config: dict): +def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict): (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -79,16 +94,14 @@ def _criterion(outputs, inputs): with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" # optimizer_ckpt_path = f"{tempdir}/optimizer" - booster.save_model(model, - model_ckpt_path, - shard=shard, - size_per_shard=size_per_shard, - use_safetensors=use_safetensors) + booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard) # booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False) + clear_layout_converter() + def run_dist(rank, world_size, port): config = {} From bc699e146f1a40b08a63dde3e08877490b303691 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 24 Aug 2023 14:21:38 +0800 Subject: [PATCH 06/11] open a new file for hybrid checkpoint_io --- .../booster/plugin/hybrid_parallel_plugin.py | 334 +---------------- colossalai/checkpoint_io/__init__.py | 3 +- .../hybrid_parallel_checkpoint_io.py | 338 ++++++++++++++++++ 3 files changed, 344 insertions(+), 331 deletions(-) create mode 100644 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1c4684f4aa62..2e591b176ab3 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,17 +1,11 @@ -import copy -import logging -import os import random from contextlib import nullcontext -from functools import partial, reduce -from pathlib import Path -from shutil import rmtree -from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union +from functools import partial +from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import numpy as np import torch import torch.distributed as dist -import torch.nn as nn from torch.distributed import ProcessGroup from torch.nn import Module, SyncBatchNorm from torch.nn.parallel import DistributedDataParallel as DDP @@ -22,39 +16,16 @@ from torch.utils.data.distributed import DistributedSampler from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer -from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO -from colossalai.checkpoint_io.utils import ( - calculate_tensor_size, - get_model_base_filenames, - get_optimizer_base_filenames, - get_shard_filename, - is_safetensors_available, - load_shard_state_dict, - load_state_dict_into_model, - save_param_groups, - save_state_dict, - save_state_dict_shards, -) -from colossalai.cluster import DistCoordinator, ProcessGroupMesh +from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO +from colossalai.cluster import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.tensor.d_tensor import ( - is_customized_distributed_tensor, - is_distributed_tensor, - to_global, - to_global_for_customized_distributed_tensor, -) from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase -try: - from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys -except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' - DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 @@ -494,300 +465,3 @@ def get_checkpoint_io(self) -> CheckpointIO: def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError - - -class HypridParallelCheckpointIO(GeneralCheckpointIO): - """ - CheckpointIO for Hybrid Parallel Training. - - Args: - pg_mesh (ProcessGroupMesh): Process group mesh containing information of process groups along different dimensions. - """ - - def __init__(self, pg_mesh: ProcessGroupMesh) -> None: - super().__init__() - self.dp_group = pg_mesh.get_group_along_axis(DP_AXIS) - self.pp_group = pg_mesh.get_group_along_axis(PP_AXIS) - self.tp_group = pg_mesh.get_group_along_axis(TP_AXIS) - self.dp_rank = dist.get_rank(self.dp_group) - self.tp_rank = dist.get_rank(self.tp_group) - self.pp_rank = dist.get_rank(self.pp_group) - self.dp_size, self.pp_size, self.tp_size = pg_mesh.shape - - @staticmethod - def _model_sharder(model: nn.Module, - prefix: str = '', - keep_vars: bool = False, - size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: - # An internel method that breaks state_dict of model into shards within limited size. - - state_dict_sharder = _StateDictSharder(size_per_shard) - - # Save parameters. - for name, param in model.named_parameters(): - if param is None: - continue - # Gather tensor pieces when using tensor parallel. - param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): - param_ = to_global(param_) - elif is_customized_distributed_tensor(param_): - param_ = to_global_for_customized_distributed_tensor(param_) - - block, block_size = state_dict_sharder.append(prefix + name, param_) - if block is not None: - yield block, block_size - - # Save buffers. - for name, buf in model.named_buffers(): - if buf is not None and name not in model._non_persistent_buffers_set: - buffer = buf if keep_vars else buf.detach() - block, block_size = state_dict_sharder.append(prefix + name, buffer) - if block is not None: - yield block, block_size - - # Save extra states. - extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(model.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: - extra_state = model.get_extra_state() - block, block_size = state_dict_sharder.append(extra_state_key, extra_state) - if block is not None: - yield block, block_size - - # Return the last block in sharder. - yield state_dict_sharder.current_block, state_dict_sharder.current_block_size - - @staticmethod - def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024): - # An internel method that breaks state_dict of optimizer into shards within limited size. - # TODO (Baizhou): Implement sharding feature of optimizer. - pass - - def save_sharded_model(self, - model: nn.Module, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - use_safetensors: bool = False) -> None: - """ - Save sharded model checkpoint under the given checkpointing path. - The following files will be created under the path: - - An index file (pytorch_optim.bin.index.json) containing a map between model params/buffers and file names. - - Multiple files that store state tensors of models. - If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". - If pipeline parallelism is not used, "pytorch_model.-000XX.bin" - - - Args: - model (nn.Module): Model on local device to be saved. - checkpoint_path (str): Checkpointing path which should be a directory path. - gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. - prefix (str, optional): Perfix of file to save. Defaults to None. - size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. - use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. - """ - - if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") - return - - Path(checkpoint).mkdir(parents=True, exist_ok=True) - - # Devices along the same dp_group share the same copies of model. - # So only let the device with dp_rank == 0 save the model. - if self.dp_rank != 0: - return - - # Then collect the sharded parameters & buffers along tp_group. - # Only devices with tp_size == 0 are responsible for model saving. - state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) - weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) - index_file = CheckpointIndexFile(checkpoint) - control_saving = (self.tp_rank == 0) - - if self.pp_size == 1: - # When pipeline is not used, save the model shards as in general checkpointIO - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors) - if control_saving: - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.") - - else: - # When pipeline is used, each stage produces its own shard files and index files. - # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ - # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. - - final_index_file_path = copy.deepcopy(save_index_file) - tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") - Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) - - # Manage filenames of sharded weights and index file for each pipeline stage. - weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") - weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors") - save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") - save_index_file = os.path.join("tmp_index_files", save_index_file) - - total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors) - if control_saving: - assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - else: - return - - dist.barrier(self.pp_group) - - # The global master rank integrates the index files and clean the folder. - if self.pp_rank == 0: - final_index_file = CheckpointIndexFile(checkpoint) - final_index_file.append_meta_data("total_size", 0) - - for filename in os.listdir(tmp_index_file_folder): - stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) - final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] - for weight, weight_filename in stage_index_file.weight_map.items(): - final_index_file.append_weight_map(weight, weight_filename) - - final_index_file.write_index_file(final_index_file_path) - rmtree(tmp_index_file_folder) - logging.info(f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {final_index_file_path}.") - - def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): - """ - Load sharded model with the given path to index file of checkpoint folder. - - Args: - model (nn.Module): The model to be loaded. - index_file_path (str): Path to the index file of checkpointing folder. - strict (bool, optional): For name matching during loading state_dict. Defaults to False. - This argument should be manually set to False since params on same device might be stored in different files. - """ - - # Check whether the checkpoint uses safetensors. - use_safetensors = False - if "safetensors" in checkpoint_index_file.name: - use_safetensors = True - - if use_safetensors and not is_safetensors_available(): - raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") - - # Read checkpoint index file. - ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) - ckpt_root_path = ckpt_index_file.root_path - weight_map = ckpt_index_file.weight_map - strict = False - - # Load params & buffers to model. - # Keep a record of loaded files so that file will not be repeatedly loaded. - loaded_file = set() - - def _load(name: str): - if name not in weight_map: - raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") - filename = weight_map[name] - - # If this param/buffer has been loaded before, directly return. - if filename in loaded_file: - return - - file_path = os.path.join(ckpt_root_path, filename) - state_dict = load_shard_state_dict(Path(file_path), use_safetensors) - missing_keys = [] - - load_state_dict_into_model(model, - state_dict, - missing_keys=missing_keys, - strict=strict, - load_sub_module=True) - loaded_file.add(filename) - - # Load parameters. - for name, _ in model.named_parameters(): - _load(name) - - # Load buffers. - for name, buf in model.named_buffers(): - if buf is not None and name not in model._non_persistent_buffers_set: - _load(name) - - # Load extra states. - extra_state_key = _EXTRA_STATE_KEY_SUFFIX - if getattr(model.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: - _load(extra_state_key) - - def save_sharded_optimizer(self, - optimizer: Optimizer, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024): - pass - - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): - pass - - def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - raise NotImplementedError - - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - raise NotImplementedError - - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - raise NotImplementedError - - def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): - # TODO(Baizhou): support this feature after implementing complete state_dict collection - raise NotImplementedError - - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): - """ - Save lr scheduler to checkpoint but only on master process. - """ - if self.coordinator.is_master(): - super().save_lr_scheduler(lr_scheduler, checkpoint) - - -class _StateDictSharder: - - def __init__(self, size_per_shard: int) -> None: - self.max_shard_size = size_per_shard - self.current_block = OrderedDict() - self.current_block_size = 0 - - def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: - tensor_size = calculate_tensor_size(tensor) - ret_block = None - ret_block_size = 0 - - # before we return the current block and create a new block, - # we need to ensure that the current block is not empty - if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: - ret_block = self.current_block - ret_block_size = self.current_block_size - self.current_block = OrderedDict() - self.current_block_size = 0 - - self.current_block[name] = tensor - self.current_block_size += tensor_size - return ret_block, ret_block_size diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index c25048e25754..07b1f81dace6 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -1,5 +1,6 @@ from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO +from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO from .index_file import CheckpointIndexFile -__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO'] +__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO'] diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py new file mode 100644 index 000000000000..28d4a89b82a6 --- /dev/null +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -0,0 +1,338 @@ +import copy +import logging +import os +from pathlib import Path +from shutil import rmtree +from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + +from colossalai.cluster import ProcessGroupMesh +from colossalai.tensor.d_tensor import ( + is_customized_distributed_tensor, + is_distributed_tensor, + to_global, + to_global_for_customized_distributed_tensor, +) + +from . import CheckpointIndexFile, GeneralCheckpointIO +from .utils import ( + calculate_tensor_size, + get_model_base_filenames, + get_optimizer_base_filenames, + get_shard_filename, + is_safetensors_available, + load_shard_state_dict, + load_state_dict_into_model, + save_param_groups, + save_state_dict, + save_state_dict_shards, +) + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + + +class HypridParallelCheckpointIO(GeneralCheckpointIO): + """ + CheckpointIO for Hybrid Parallel Training. + + Args: + pg_mesh (ProcessGroupMesh): Process group mesh containing information of process groups along different dimensions. + """ + + def __init__(self, pg_mesh: ProcessGroupMesh) -> None: + super().__init__() + self.dp_group = pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = pg_mesh.get_group_along_axis(PP_AXIS) + self.tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + self.dp_rank = dist.get_rank(self.dp_group) + self.tp_rank = dist.get_rank(self.tp_group) + self.pp_rank = dist.get_rank(self.pp_group) + self.dp_size, self.pp_size, self.tp_size = pg_mesh.shape + + @staticmethod + def _model_sharder(model: nn.Module, + prefix: str = '', + keep_vars: bool = False, + size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: + # An internel method that breaks state_dict of model into shards within limited size. + + state_dict_sharder = _StateDictSharder(size_per_shard) + + # Save parameters. + for name, param in model.named_parameters(): + if param is None: + continue + # Gather tensor pieces when using tensor parallel. + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + param_ = to_global(param_) + elif is_customized_distributed_tensor(param_): + param_ = to_global_for_customized_distributed_tensor(param_) + + block, block_size = state_dict_sharder.append(prefix + name, param_) + if block is not None: + yield block, block_size + + # Save buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + buffer = buf if keep_vars else buf.detach() + block, block_size = state_dict_sharder.append(prefix + name, buffer) + if block is not None: + yield block, block_size + + # Save extra states. + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(model.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + extra_state = model.get_extra_state() + block, block_size = state_dict_sharder.append(extra_state_key, extra_state) + if block is not None: + yield block, block_size + + # Return the last block in sharder. + yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + + @staticmethod + def _optimizer_sharder(optimizer: Optimizer, size_per_shard: int = 1024): + # An internel method that breaks state_dict of optimizer into shards within limited size. + # TODO (Baizhou): Implement sharding feature of optimizer. + pass + + def save_sharded_model(self, + model: nn.Module, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False) -> None: + """ + Save sharded model checkpoint under the given checkpointing path. + The following files will be created under the path: + - An index file (pytorch_optim.bin.index.json) containing a map between model params/buffers and file names. + - Multiple files that store state tensors of models. + If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". + If pipeline parallelism is not used, "pytorch_model.-000XX.bin" + + + Args: + model (nn.Module): Model on local device to be saved. + checkpoint_path (str): Checkpointing path which should be a directory path. + gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. + prefix (str, optional): Perfix of file to save. Defaults to None. + size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. + use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False. + """ + + if os.path.isfile(checkpoint): + logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Devices along the same dp_group share the same copies of model. + # So only let the device with dp_rank == 0 save the model. + if self.dp_rank != 0: + return + + # Then collect the sharded parameters & buffers along tp_group. + # Only devices with tp_size == 0 are responsible for model saving. + state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint) + control_saving = (self.tp_rank == 0) + + if self.pp_size == 1: + # When pipeline is not used, save the model shards as in general checkpointIO + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + else: + # When pipeline is used, each stage produces its own shard files and index files. + # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ + # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. + + final_index_file_path = copy.deepcopy(save_index_file) + tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") + Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) + + # Manage filenames of sharded weights and index file for each pipeline stage. + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") + save_index_file = os.path.join("tmp_index_files", save_index_file) + + total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors) + if control_saving: + assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0." + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + else: + return + + dist.barrier(self.pp_group) + + # The global master rank integrates the index files and clean the folder. + if self.pp_rank == 0: + final_index_file = CheckpointIndexFile(checkpoint) + final_index_file.append_meta_data("total_size", 0) + + for filename in os.listdir(tmp_index_file_folder): + stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename)) + final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"] + for weight, weight_filename in stage_index_file.weight_map.items(): + final_index_file.append_weight_map(weight, weight_filename) + + final_index_file.write_index_file(final_index_file_path) + rmtree(tmp_index_file_folder) + logging.info(f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {final_index_file_path}.") + + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False): + """ + Load sharded model with the given path to index file of checkpoint folder. + + Args: + model (nn.Module): The model to be loaded. + index_file_path (str): Path to the index file of checkpointing folder. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since params on same device might be stored in different files. + """ + + # Check whether the checkpoint uses safetensors. + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + ckpt_root_path = ckpt_index_file.root_path + weight_map = ckpt_index_file.weight_map + strict = False + + # Load params & buffers to model. + # Keep a record of loaded files so that file will not be repeatedly loaded. + loaded_file = set() + + def _load(name: str): + if name not in weight_map: + raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!") + filename = weight_map[name] + + # If this param/buffer has been loaded before, directly return. + if filename in loaded_file: + return + + file_path = os.path.join(ckpt_root_path, filename) + state_dict = load_shard_state_dict(Path(file_path), use_safetensors) + missing_keys = [] + + load_state_dict_into_model(model, + state_dict, + missing_keys=missing_keys, + strict=strict, + load_sub_module=True) + loaded_file.add(filename) + + # Load parameters. + for name, _ in model.named_parameters(): + _load(name) + + # Load buffers. + for name, buf in model.named_buffers(): + if buf is not None and name not in model._non_persistent_buffers_set: + _load(name) + + # Load extra states. + extra_state_key = _EXTRA_STATE_KEY_SUFFIX + if getattr(model.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + _load(extra_state_key) + + def save_sharded_optimizer(self, + optimizer: Optimizer, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024): + pass + + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str): + pass + + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + # TODO(Baizhou): support this feature after implementing complete state_dict collection + raise NotImplementedError + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save lr scheduler to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + +class _StateDictSharder: + + def __init__(self, size_per_shard: int) -> None: + self.max_shard_size = size_per_shard + self.current_block = OrderedDict() + self.current_block_size = 0 + + def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + tensor_size = calculate_tensor_size(tensor) + ret_block = None + ret_block_size = 0 + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[name] = tensor + self.current_block_size += tensor_size + return ret_block, ret_block_size From 1e3d73e746acb5d0501e288acdf3a6ffaae1d5d5 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 24 Aug 2023 14:47:38 +0800 Subject: [PATCH 07/11] small fix --- colossalai/booster/plugin/hybrid_parallel_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 2e591b176ab3..90e1d3efc067 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -292,7 +292,6 @@ def __init__(self, self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) - self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, From 054d9300b634d08d4f6621440eb00f0ec5f46011 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 24 Aug 2023 15:19:22 +0800 Subject: [PATCH 08/11] fix circular importing --- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 28d4a89b82a6..116b5bff939a 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -19,7 +19,8 @@ to_global_for_customized_distributed_tensor, ) -from . import CheckpointIndexFile, GeneralCheckpointIO +from .general_checkpoint_io import GeneralCheckpointIO +from .index_file import CheckpointIndexFile from .utils import ( calculate_tensor_size, get_model_base_filenames, From bcb86898ed1c3f2db5f49c90e3a35e2e63dce8e4 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 24 Aug 2023 15:51:43 +0800 Subject: [PATCH 09/11] fix docstring --- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 116b5bff939a..1028ae612f95 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -120,7 +120,7 @@ def save_sharded_model(self, """ Save sharded model checkpoint under the given checkpointing path. The following files will be created under the path: - - An index file (pytorch_optim.bin.index.json) containing a map between model params/buffers and file names. + - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names. - Multiple files that store state tensors of models. If pipeline parallelism is used, the filenames are in the form of "pytorch_model.-stage-000XX-shard-000XX.bin". If pipeline parallelism is not used, "pytorch_model.-000XX.bin" @@ -128,7 +128,7 @@ def save_sharded_model(self, Args: model (nn.Module): Model on local device to be saved. - checkpoint_path (str): Checkpointing path which should be a directory path. + checkpoint (str): Checkpointing path which should be a directory path. gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True. prefix (str, optional): Perfix of file to save. Defaults to None. size_per_shard (int, optional): Size per shard in MB. Defaults to 1024. From 8a774c25d9c298d534ec45989082edfcb56e0a18 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 25 Aug 2023 14:36:25 +0800 Subject: [PATCH 10/11] arrange arguments and apis --- .../booster/plugin/hybrid_parallel_plugin.py | 3 +- .../hybrid_parallel_checkpoint_io.py | 55 ++++++------------- colossalai/checkpoint_io/utils.py | 55 ++++++++++++++++++- .../shardformer/layer/parallel_module.py | 9 +-- colossalai/zero/gemini/gemini_ddp.py | 28 +--------- 5 files changed, 76 insertions(+), 74 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 90e1d3efc067..c49b3e1823cd 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -292,6 +292,7 @@ def __init__(self, self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager) self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, pipeline_stage_manager=self.stage_manager, enable_tensor_parallelism=self.tp_size > 1, @@ -460,7 +461,7 @@ def seed_worker(worker_id): **_kwargs) def get_checkpoint_io(self) -> CheckpointIO: - return HypridParallelCheckpointIO(self.pg_mesh) + return HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group) def no_sync(self, model: Module) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 1028ae612f95..1263b43bf652 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -1,4 +1,5 @@ import copy +import gc import logging import os from pathlib import Path @@ -8,6 +9,7 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler @@ -22,7 +24,9 @@ from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile from .utils import ( + StateDictSharder, calculate_tensor_size, + gather_distributed_param, get_model_base_filenames, get_optimizer_base_filenames, get_shard_filename, @@ -47,18 +51,22 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO): CheckpointIO for Hybrid Parallel Training. Args: - pg_mesh (ProcessGroupMesh): Process group mesh containing information of process groups along different dimensions. + dp_group (ProcessGroup): Process group along data parallel dimension. + pp_group (ProcessGroup): Process group along pipeline parallel dimension. + tp_group (ProcessGroup): Process group along tensor parallel dimension. """ - def __init__(self, pg_mesh: ProcessGroupMesh) -> None: + def __init__(self, dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup) -> None: super().__init__() - self.dp_group = pg_mesh.get_group_along_axis(DP_AXIS) - self.pp_group = pg_mesh.get_group_along_axis(PP_AXIS) - self.tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + self.dp_group = dp_group + self.pp_group = pp_group + self.tp_group = tp_group self.dp_rank = dist.get_rank(self.dp_group) self.tp_rank = dist.get_rank(self.tp_group) self.pp_rank = dist.get_rank(self.pp_group) - self.dp_size, self.pp_size, self.tp_size = pg_mesh.shape + self.dp_size = dist.get_world_size(dp_group) + self.pp_size = dist.get_world_size(pp_group) + self.tp_size = dist.get_world_size(tp_group) @staticmethod def _model_sharder(model: nn.Module, @@ -67,19 +75,14 @@ def _model_sharder(model: nn.Module, size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: # An internel method that breaks state_dict of model into shards within limited size. - state_dict_sharder = _StateDictSharder(size_per_shard) + state_dict_sharder = StateDictSharder(size_per_shard) # Save parameters. for name, param in model.named_parameters(): if param is None: continue # Gather tensor pieces when using tensor parallel. - param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): - param_ = to_global(param_) - elif is_customized_distributed_tensor(param_): - param_ = to_global_for_customized_distributed_tensor(param_) - + param_ = gather_distributed_param(param, keep_vars=False) block, block_size = state_dict_sharder.append(prefix + name, param_) if block is not None: yield block, block_size @@ -262,6 +265,7 @@ def _load(name: str): missing_keys=missing_keys, strict=strict, load_sub_module=True) + del state_dict loaded_file.add(filename) # Load parameters. @@ -312,28 +316,3 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) - - -class _StateDictSharder: - - def __init__(self, size_per_shard: int) -> None: - self.max_shard_size = size_per_shard - self.current_block = OrderedDict() - self.current_block_size = 0 - - def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: - tensor_size = calculate_tensor_size(tensor) - ret_block = None - ret_block_size = 0 - - # before we return the current block and create a new block, - # we need to ensure that the current block is not empty - if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: - ret_block = self.current_block - ret_block_size = self.current_block_size - self.current_block = OrderedDict() - self.current_block_size = 0 - - self.current_block[name] = tensor - self.current_block_size += tensor_size - return ret_block, ret_block_size diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 8837776aee4d..ba84c699d009 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -13,7 +13,12 @@ from colossalai.interface import OptimizerWrapper from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.tensor.d_tensor import is_distributed_tensor +from colossalai.tensor.d_tensor import ( + is_customized_distributed_tensor, + is_distributed_tensor, + to_global, + to_global_for_customized_distributed_tensor, +) SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -88,8 +93,28 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False +def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False): + """ + Gather the complete parameter for saving if passed in param is distributed. + + Args: + param (torch.Tensor): A model parameter, might be d_tensor. + keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False. + + Returns: + torch.Tensor: the complete parameter + """ + param_ = param if keep_vars else param.detach() + if is_distributed_tensor(param_): + return to_global(param_) + elif is_customized_distributed_tensor(param_): + return to_global_for_customized_distributed_tensor(param_) + else: + return param_ + + # ====================================== -# Helper functions for saving shard file +# Helper classes and functions for saving shard file # ====================================== def unwrap_optimizer(optimizer: OptimizerWrapper): ''' @@ -104,6 +129,31 @@ def unwrap_optimizer(optimizer: OptimizerWrapper): return unwrapped_optim +class StateDictSharder: + + def __init__(self, size_per_shard: int) -> None: + self.max_shard_size = size_per_shard + self.current_block = OrderedDict() + self.current_block_size = 0 + + def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: + tensor_size = calculate_tensor_size(tensor) + ret_block = None + ret_block_size = 0 + + # before we return the current block and create a new block, + # we need to ensure that the current block is not empty + if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: + ret_block = self.current_block + ret_block_size = self.current_block_size + self.current_block = OrderedDict() + self.current_block_size = 0 + + self.current_block[name] = tensor + self.current_block_size += tensor_size + return ret_block, ret_block_size + + def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, index_file: "CheckpointIndexFile", @@ -137,6 +187,7 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] # Only save on master rank. save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors) + del shard return total_size diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index bda147b121ab..4f391920e29b 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module +from colossalai.checkpoint_io.utils import gather_distributed_param from colossalai.tensor.d_tensor import ( distribute_tensor, distribute_tensor_with_customization, @@ -56,13 +57,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): """ for name, param in self._parameters.items(): if param is not None: - param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): - destination[prefix + name] = to_global(param_) - elif is_customized_distributed_tensor(param_): - destination[prefix + name] = to_global_for_customized_distributed_tensor(param_) - else: - destination[prefix + name] = param_ + destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars) for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 08384ee82d0b..5aff91f03153 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -8,7 +8,7 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.checkpoint_io.utils import calculate_tensor_size +from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage @@ -657,7 +657,7 @@ def state_dict_shard(self, Yields: Iterator[OrderedDict]: A generator of state dict shard """ - sharder = _StateDictSharder(max_shard_size) + sharder = StateDictSharder(max_shard_size) # get the mapping between copies and fp16 parameters fp16_to_fp32 = dict() @@ -705,30 +705,6 @@ def state_dict_shard(self, yield sharder.current_block, sharder.current_block_size -class _StateDictSharder: - - def __init__(self, max_shard_size: int) -> None: - self.max_shard_size = max_shard_size - self.current_block = OrderedDict() - self.current_block_size = 0 - - def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: - tensor_size = calculate_tensor_size(tensor) - ret_block = None - ret_block_size = 0 - - # before we return the current block and create a new block, - # we need to ensure that the current block is not empty - if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0: - ret_block = self.current_block - ret_block_size = self.current_block_size - self.current_block = OrderedDict() - self.current_block_size = 0 - self.current_block[name] = tensor - self.current_block_size += tensor_size - return ret_block, ret_block_size - - class GeminiDDP(ZeroDDP): def __init__(self, From 504c9f72f1787513ee8bba74f9d3d3c9ad97c03d Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Fri, 25 Aug 2023 20:42:23 +0800 Subject: [PATCH 11/11] small fix --- colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 2 -- colossalai/checkpoint_io/utils.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 1263b43bf652..56a89bff75ca 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -43,8 +43,6 @@ except ImportError: _EXTRA_STATE_KEY_SUFFIX = '_extra_state' -DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 - class HypridParallelCheckpointIO(GeneralCheckpointIO): """ diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index ba84c699d009..d04159c54d5e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -176,9 +176,10 @@ def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]] total_size = 0 for idx, shard_pair in enumerate(sharded_state_dict): + shard, current_size = shard_pair if not is_master: + del shard continue - shard, current_size = shard_pair shard_file = get_shard_filename(base_filename, idx) total_size = total_size + current_size for key in shard.keys():