From 606a8d472b6e2a1f0d364c36395c0099549790a2 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 16 Aug 2024 06:28:33 +0000 Subject: [PATCH 1/5] use dist logger in plugins --- colossalai/booster/plugin/gemini_plugin.py | 17 +- .../booster/plugin/hybrid_parallel_plugin.py | 27 +- .../booster/plugin/low_level_zero_plugin.py | 33 +- .../plugin/moe_hybrid_parallel_plugin.py | 11 +- colossalai/booster/plugin/plugin/__init__.py | 23 + .../booster/plugin/plugin/dp_plugin_base.py | 76 + .../booster/plugin/plugin/gemini_plugin.py | 595 +++++++ .../plugin/plugin/hybrid_parallel_plugin.py | 1466 +++++++++++++++++ .../plugin/plugin/low_level_zero_plugin.py | 521 ++++++ .../plugin/moe_hybrid_parallel_plugin.py | 490 ++++++ .../booster/plugin/plugin/plugin_base.py | 90 + .../booster/plugin/plugin/pp_plugin_base.py | 22 + .../booster/plugin/plugin/torch_ddp_plugin.py | 257 +++ .../plugin/plugin/torch_fsdp_plugin.py | 372 +++++ .../booster/plugin/torch_fsdp_plugin.py | 15 +- 15 files changed, 3966 insertions(+), 49 deletions(-) create mode 100644 colossalai/booster/plugin/plugin/__init__.py create mode 100644 colossalai/booster/plugin/plugin/dp_plugin_base.py create mode 100644 colossalai/booster/plugin/plugin/gemini_plugin.py create mode 100644 colossalai/booster/plugin/plugin/hybrid_parallel_plugin.py create mode 100644 colossalai/booster/plugin/plugin/low_level_zero_plugin.py create mode 100644 colossalai/booster/plugin/plugin/moe_hybrid_parallel_plugin.py create mode 100644 colossalai/booster/plugin/plugin/plugin_base.py create mode 100644 colossalai/booster/plugin/plugin/pp_plugin_base.py create mode 100644 colossalai/booster/plugin/plugin/torch_ddp_plugin.py create mode 100644 colossalai/booster/plugin/plugin/torch_fsdp_plugin.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index ad131fbe739a..443c80831b14 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,5 +1,4 @@ import gc -import logging import os import random from pathlib import Path @@ -27,6 +26,7 @@ ) from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() + self.logger = get_dist_logger() def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ @@ -118,7 +119,7 @@ def save_sharded_model( """ assert isinstance(model, GeminiDDP), "Please boost the model before saving!" if os.path.isfile(checkpoint_path): - logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") return Path(checkpoint_path).mkdir(parents=True, exist_ok=True) @@ -143,7 +144,7 @@ def save_sharded_model( index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) save_config_file(model.unwrap(), checkpoint_path) - logging.info( + self.logger.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." 
@@ -168,7 +169,7 @@ def save_sharded_optimizer( assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -201,7 +202,7 @@ def save_sharded_optimizer( if self.coordinator.is_master(): index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info( + self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." @@ -214,7 +215,7 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi """ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" if not os.path.isfile(checkpoint_index_file): - logging.error(f"Provided path ({checkpoint_index_file}) should be a file") + self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file") assert isinstance(optimizer, GeminiOptimizer) @@ -369,8 +370,10 @@ def __init__( assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" if get_accelerator().name == "npu": assert placement_policy == "static", "NPU only supports static placement policy" + + self.logger = get_dist_logger() if enable_async_reduce and not pin_memory: - logging.warning( + self.logger.warning( f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set." ) pin_memory = True diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 63427192f482..6c1515d38834 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1,6 +1,5 @@ import ctypes import random -import warnings from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import deepcopy @@ -27,6 +26,7 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager @@ -1023,6 +1023,7 @@ def __init__( inner_ring_size: int = None, ) -> None: super().__init__() + self.logger = get_dist_logger() assert ( dist.get_world_size() % (tp_size * pp_size) == 0 @@ -1040,7 +1041,7 @@ def __init__( tp_size > 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" if sp_size != 1: - warnings.warn( + self.logger.warning( f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." ) self.sp_size = 1 @@ -1126,7 +1127,11 @@ def __init__( else: raise NotImplementedError() if sequence_parallelism_mode == "ring_attn": - assert parallel_output, "Ring Attention doesn't support gathering output yet." 
+ if not parallel_output: + self.logger.warning( + "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet." + ) + parallel_output = True self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) @@ -1231,7 +1236,9 @@ def configure( optimizer = cast_to_distributed(optimizer) if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: - warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.") + self.logger.warning( + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO." + ) zero_config["partition_grad"] = False zero_stage = 0 @@ -1287,7 +1294,7 @@ def configure( else: is_zero = self.dp_size > 1 if self.dp_size == 1: - warnings.warn( + self.logger.warning( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " "If you do not intend to use cpu_offload, please consider set zero_stage=0." ) @@ -1332,7 +1339,7 @@ def execute_pipeline( assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" if return_outputs: - warnings.warn("return_outputs may lead to significant extra memory consumption.") + self.logger.warning("return_outputs may lead to significant extra memory consumption.") # Create a context for gradient synchronization based on the optimizer type. # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). @@ -1346,10 +1353,8 @@ def execute_pipeline( ) # run with gradients accumulation - if ( - model.require_grad_sync == False - or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False) - or not torch.is_grad_enabled() + if model.require_grad_sync == False or ( + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False ): return outputs @@ -1449,7 +1454,7 @@ def enable_lora( assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." assert self.pp_size == 1 and self.tp_size == 1 self.lora_enabled = True - warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + self.logger.warning("You have enabled LoRa training. 
Please check the hyperparameters such as lr") if bnb_quantization_config is not None: model = quantize_model(model, bnb_quantization_config) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index e4c386a2257d..6c36bad3c214 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,7 +1,5 @@ import enum -import logging import os -import warnings from contextlib import nullcontext from functools import partial from pathlib import Path @@ -33,6 +31,7 @@ ) from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.tensor.colo_parameter import ColoParameter @@ -62,9 +61,7 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__( - self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True - ) -> None: + def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -76,7 +73,7 @@ def __init__( module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None - if self.dtype is not None and cast_inputs: + if self.dtype is not None: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.overlap_allgather = overlap_allgather if overlap_allgather: @@ -115,6 +112,7 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, state_dict = optimizer.state_dict() if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors=False) + self.logger = get_dist_logger() def save_sharded_optimizer( self, @@ -140,7 +138,7 @@ def save_sharded_optimizer( """ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -177,7 +175,7 @@ def save_sharded_optimizer( index_file.append_meta_data("total_size", total_size) if self.coordinator.is_master(): index_file.write_index_file(save_index_file) - logging.info( + self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." 
@@ -267,7 +265,7 @@ def save_sharded_model( def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") return from peft import PeftModel @@ -336,7 +334,6 @@ def __init__( cpu_offload: bool = False, master_weights: bool = True, verbose: bool = False, - cast_inputs: bool = True, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" @@ -363,8 +360,7 @@ def __init__( ) self.lora_enabled = False self.verbose = verbose - self.cast_inputs = cast_inputs - + self.logger = get_dist_logger() # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -400,7 +396,7 @@ def enable_lora( assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model." self.lora_enabled = True - warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr") + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr") if bnb_quantization_config is not None: model = quantize_model(model, bnb_quantization_config) @@ -449,7 +445,7 @@ def add_lora_params_to_optimizer(self, model, optimizer): origin_param = name2param[origin_key] group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: - warnings.warn( + self.logger.warning( f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." ) elif ( @@ -478,10 +474,7 @@ def configure( if not isinstance(model, ModelWrapper): model = LowLevelZeroModel( - model, - self.precision, - overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], - cast_inputs=self.cast_inputs, + model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] ) # TODO: Support Galore + ZeRO @@ -493,7 +486,9 @@ def configure( optimizer = cast_to_distributed(optimizer) if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0: - warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.") + self.logger.warning( + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO." 
+ ) zero_optim_kwargs["partition_grad"] = False zero_stage = 0 diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index b3415af0eed6..874028f09b86 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,4 +1,3 @@ -import warnings from collections import defaultdict from types import MethodType from typing import Callable, List, Optional, OrderedDict, Tuple @@ -26,6 +25,7 @@ from colossalai.cluster.process_group_mesh import ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import cast_to_distributed from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule @@ -215,10 +215,11 @@ def __init__( overlap_p2p: bool = True, overlap_allgather: bool = False, ) -> None: + self.logger = get_dist_logger() if overlap_communication or zero_stage == 2: overlap_communication = False zero_stage = 1 - warnings.warn( + self.logger.warning( f"overlap_communication and zero_stage are set to False and 1 because " f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. " ) @@ -238,7 +239,7 @@ def __init__( tp_size > 1 ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" if sp_size != 1: - warnings.warn( + self.logger.warning( f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." ) self.sp_size = 1 @@ -400,7 +401,7 @@ def configure( and self.sequence_parallelism_mode == "all_to_all" ) if use_ddp: - warnings.warn( + self.logger.warning( f"Will have to check all params are used in pytorch DDP since not all experts are always activated" ) self.ddp_config["find_unused_parameters"] = True @@ -457,7 +458,7 @@ def configure( ) else: if self.dp_size <= 1: - warnings.warn( + self.logger.warning( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " "If you do not intend to use cpu_offload, please consider set zero_stage=0." 
) diff --git a/colossalai/booster/plugin/plugin/__init__.py b/colossalai/booster/plugin/plugin/__init__.py new file mode 100644 index 000000000000..7e0e6ffdd8e8 --- /dev/null +++ b/colossalai/booster/plugin/plugin/__init__.py @@ -0,0 +1,23 @@ +from .gemini_plugin import GeminiPlugin +from .hybrid_parallel_plugin import HybridParallelPlugin +from .low_level_zero_plugin import LowLevelZeroPlugin +from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from .plugin_base import Plugin +from .torch_ddp_plugin import TorchDDPPlugin + +__all__ = [ + "Plugin", + "TorchDDPPlugin", + "GeminiPlugin", + "LowLevelZeroPlugin", + "HybridParallelPlugin", + "MoeHybridParallelPlugin", +] + +import torch +from packaging import version + +if version.parse(torch.__version__) >= version.parse("1.12.0"): + from .torch_fsdp_plugin import TorchFSDPPlugin + + __all__.append("TorchFSDPPlugin") diff --git a/colossalai/booster/plugin/plugin/dp_plugin_base.py b/colossalai/booster/plugin/plugin/dp_plugin_base.py new file mode 100644 index 000000000000..27285f95ce52 --- /dev/null +++ b/colossalai/booster/plugin/plugin/dp_plugin_base.py @@ -0,0 +1,76 @@ +import random + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from .plugin_base import Plugin + + +class DPPluginBase(Plugin): + """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation.""" + + def __init__(self) -> None: + super().__init__() + assert ( + dist.is_initialized() + ), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment" + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + + def prepare_dataloader( + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, + ): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
+ """ + _kwargs = kwargs.copy() + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) diff --git a/colossalai/booster/plugin/plugin/gemini_plugin.py b/colossalai/booster/plugin/plugin/gemini_plugin.py new file mode 100644 index 000000000000..443c80831b14 --- /dev/null +++ b/colossalai/booster/plugin/plugin/gemini_plugin.py @@ -0,0 +1,595 @@ +import gc +import os +import random +from pathlib import Path +from typing import Callable, Dict, Iterator, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.distributed_c10d import _get_default_group +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.accelerator import get_accelerator +from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO +from colossalai.checkpoint_io.utils import ( + get_model_base_filenames, + get_optimizer_base_filenames, + load_shard_state_dict, + save_config_file, + save_state_dict, + save_state_dict_shards, +) +from colossalai.cluster import DistCoordinator, ProcessGroupMesh +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.memory_tracer import MemStats + +from .dp_plugin_base import DPPluginBase + +__all__ = ["GeminiPlugin"] + +SUPPORTED_PRECISION = ["fp16", "bf16"] +PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} + +ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 + + +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A mapping from integer param_id to param32 shape. + if optim is None: + return {} + param_info = {"id2shape": {}} + + start_index = 0 + for group in optim.param_groups: + for param_id, param in enumerate(group["params"], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + param_info["id2shape"][param_id] = original_shape + + start_index += len(group["params"]) + + return param_info + + +class GeminiCheckpointIO(GeneralCheckpointIO): + def __init__(self) -> None: + super().__init__() + self.coordinator = DistCoordinator() + self.logger = get_dist_logger() + + def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save sharded model to checkpoint but only on master process. + The model should be unwrapped in self.load_model via ModelWrapper.unwrap. + As there is communication when getting state dict, model.state_dict() must be called on all processes. + """ + assert isinstance(model, GeminiDDP), "Please boost the model before saving!" 
+ state_dict = model.state_dict(only_rank_0=True) + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors) + + def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): + """ + Load model from checkpoint with automatic unwrapping. + The model should be unwrapped in self.load_model via ModelWrapper.unwrap. + """ + assert isinstance(model, GeminiDDP), "Please boost the model before loading!" + super().load_unsharded_model(model, checkpoint, strict=strict) + + def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool): + """ + Save unsharded optimizer state dict to checkpoint. + After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank. + As there is communication when getting state dict, optimizer.state_dict() must be called on all processes. + The saving process will only be executed by master rank. + """ + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" + state_dict = optimizer.state_dict() + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors=False) + + def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str): + """ + Loading unsharded optimizer from checkpoint file. + For each process, only loading optimizer states of parameters it controls. + """ + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" + super().load_unsharded_optimizer(optimizer, checkpoint) + + def save_sharded_model( + self, + model: GeminiDDP, + checkpoint_path: str, + gather_dtensor: bool = False, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): + """ + Save sharded model. + As there is communication when getting state dict, model.state_dict() must be called on all processes. + """ + assert isinstance(model, GeminiDDP), "Please boost the model before saving!" + if os.path.isfile(checkpoint_path): + self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + return + + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint_path) + + # Save shards of optimizer states. + is_master = self.coordinator.is_master() + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=is_master, + use_safetensors=use_safetensors, + ) + + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model.unwrap(), checkpoint_path) + self.logger.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_sharded_model( + self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False + ): + """ + Load shard model, load model from multiple files. + """ + assert isinstance(model, GeminiDDP), "Please boost the model before loading!" 
+ return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) + + def save_sharded_optimizer( + self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int + ): + """ + Save sharded optimizer state dict to checkpoint folder. + As there is communication when getting state dict, this must be called on all processes. + """ + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" + + if os.path.isfile(checkpoint): + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + index_file.append_meta_data("param_groups", param_group_file) + + # Store the information of param groups to param_group_file. + if self.coordinator.is_master(): + group_file_path = os.path.join(checkpoint, param_group_file) + param_groups = optimizer.get_param_groups_for_saving() + torch.save(param_groups, group_file_path) + + # States are broken into shards within max_shard_size. + state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True) + + # Save shards of optimizer states. + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=self.coordinator.is_master(), + use_safetensors=False, + ) + + # Wrap up index file. Only save it on master rank. + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + self.logger.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str): + """ + Loading sharded optimizer from checkpoint folder, with index file given. + For each process, only loading optimizer states of parameters it controls. + """ + assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" + if not os.path.isfile(checkpoint_index_file): + self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file") + + assert isinstance(optimizer, GeminiOptimizer) + + # Read checkpoint index file. + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + + # Load param_groups. + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {checkpoint_index_file} for an optimizer. \ + Lacking param group file under current directory." + ) + saved_param_groups = torch.load(param_group_path) + optimizer.load_param_groups(saved_param_groups) + + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + + # Load optimizer states from shard files under checkpoint path. + # For each file, only load the states managed by current process. 
+ for shard_file in checkpoint_files: + state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False) + optimizer.load_param_states(state_dict_shard) + del state_dict_shard + gc.collect() + + optimizer.optimizer_loading_epilogue() + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save model to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + +class GeminiPlugin(DPPluginBase): + """ + Plugin for Gemini. + + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import GeminiPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = GeminiPlugin() + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` + + Args: + chunk_config_dict (dict, optional): chunk configuration dictionary. + chunk_init_device (torch.device, optional): device to initialize the chunk. + placement_policy (str, optional): "static" and "auto". Defaults to "static". + enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradient will be stored after doing backward pass. Defaults to False. + shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement. + If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0. + offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement. + If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0. + offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement. + For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0. + If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement. + When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`. + Defaults to 0.0. + warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8. + steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9. + precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'. + master_weights (bool, optional): Whether to keep fp32 master parameter weights in optimizer. Defaults to True. + pin_memory (bool, optional): use pin memory on CPU. Defaults to False. + force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. + strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. + search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32. + hidden_dim (int, optional): the hidden dimension of DNN. + Users can provide this argument to speed up searching. + If users do not know this argument before training, it is ok. We will use a default value 1024. + min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20. 
+ If the aggregate size of parameters is still smaller than the minimum chunk size, + all parameters will be compacted into one small chunk. + memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. + gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) + which will be used when using hybrid CPU optimizer. + This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". + Defaults to 0.0. + initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16. + min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. + growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. + backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. + growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. + hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. + max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. + max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do + clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. + norm_type (float, optional): norm_type used for `clip_grad_norm`. + tp_size (int, optional): If 'tp_size' is set to be greater than 1, it means using tensor parallelism strategy, which is implemented in Shardformer, 'tp_size' determines the size of the tensor parallel process group. Default to 1. + extra_dp_size (int, optional): If 'extra_dp_size' is set to be greater than 1, it means creating another group to run with a ddp-like strategy. Default to 1. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. + verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. 
+ """ + + def __init__( + self, + chunk_config_dict: Optional[dict] = None, + chunk_init_device: Optional[torch.device] = None, + placement_policy: str = "static", + enable_gradient_accumulation: bool = False, + max_prefetch: int = 0, + shard_param_frac: float = 1.0, # only for static placement + offload_optim_frac: float = 0.0, # only for static placement + offload_param_frac: float = 0.0, # only for static placement + warmup_non_model_data_ratio: float = 0.8, # only for auto placement + steady_cuda_cap_ratio: float = 0.9, # only for auto placement + precision: str = "fp16", + master_weights: bool = True, + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False, + search_range_m: int = 32, + hidden_dim: Optional[int] = None, + min_chunk_size_m: float = 32, + memstats: Optional[MemStats] = None, + gpu_margin_mem_ratio: float = 0.0, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + tp_size: int = 1, + extra_dp_size: int = 1, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_sequence_parallelism: bool = False, + enable_jit_fused: bool = False, + enable_sequence_overlap: bool = False, + enable_async_reduce: bool = True, + verbose: bool = False, + ) -> None: + super().__init__() + assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" + if get_accelerator().name == "npu": + assert placement_policy == "static", "NPU only supports static placement policy" + + self.logger = get_dist_logger() + if enable_async_reduce and not pin_memory: + self.logger.warning( + f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set." 
+ ) + pin_memory = True + self.gemini_config = dict( + chunk_config_dict=chunk_config_dict, + chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()), + placement_policy=placement_policy, + enable_gradient_accumulation=enable_gradient_accumulation, + shard_param_frac=shard_param_frac, + offload_optim_frac=offload_optim_frac, + offload_param_frac=offload_param_frac, + warmup_non_model_data_ratio=warmup_non_model_data_ratio, + steady_cuda_cap_ratio=steady_cuda_cap_ratio, + pin_memory=pin_memory, + force_outputs_fp32=force_outputs_fp32, + strict_ddp_mode=strict_ddp_mode, + search_range_m=search_range_m, + hidden_dim=hidden_dim, + min_chunk_size_m=min_chunk_size_m, + memstats=memstats, + mixed_precision=PRECISION_STR_TO_DTYPE[precision], + master_weights=master_weights, + max_prefetch=max_prefetch, + enable_async_reduce=enable_async_reduce, + ) + self.zero_optim_config = dict( + gpu_margin_mem_ratio=gpu_margin_mem_ratio, + ) + self.optim_kwargs = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + max_norm=max_norm, + norm_type=norm_type, + ) + self.enable_tensor_parallelism = tp_size > 1 + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_overlap = enable_sequence_overlap + self.verbose = verbose + + self.tp_size = tp_size + self.extra_dp_size = extra_dp_size + world_size = dist.get_world_size() + self.zero_size = world_size // (self.tp_size * self.extra_dp_size) + assert ( + world_size == (self.tp_size * self.extra_dp_size) * self.zero_size + ), f"The global group size can't be evenly divided by the subgroup size." 
+ + self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size) + self.zero_group = ( + self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group() + ) + self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None + self.dp_size = self.zero_size * self.extra_dp_size + + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + enable_tensor_parallelism=self.enable_tensor_parallelism, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=self.enable_sequence_parallelism, + enable_sequence_overlap=self.enable_sequence_overlap, + ) + + def __del__(self): + """Destroy the process groups in ProcessGroupMesh""" + self.pg_mesh.destroy_mesh_process_groups() + + def support_no_sync(self) -> bool: + return False + + def support_lora(self) -> bool: + return False + + def control_precision(self) -> bool: + return True + + def supported_precisions(self) -> List[str]: + return SUPPORTED_PRECISION + + def control_device(self) -> bool: + return True + + def supported_devices(self) -> List[str]: + return ["cuda", "npu"] + + def prepare_dataloader( + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, + ): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
+ """ + _kwargs = kwargs.copy() + zero_world_size = self.pg_mesh.size(ZERO_AXIS) + extra_dp_world_size = self.pg_mesh.size(DP_AXIS) + zero_rank = self.pg_mesh.coordinate(ZERO_AXIS) + extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS) + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls( + dataset, + num_replicas=zero_world_size * extra_dp_world_size, + rank=zero_rank * extra_dp_world_size + extra_dp_rank, + shuffle=shuffle, + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + + def configure( + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + params_info = get_param_info(optimizer) + if not isinstance(model, ModelWrapper): + # convert model to sync bn + # FIXME(ver217): gemini does not support sync bn + # In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16. + # This inconsistency of dtype will cause the error. + # We have two possible solutions: + # 1. keep batch norm always in fp32. This is hard for gemini, as it use chunks. + # 2. patch sync bn or write a new on. This is relatively easy, but we need to test it. + # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) + + # wrap the model with Gemini + if self.enable_tensor_parallelism: + shardformer = ShardFormer(self.shard_config) + model, _ = shardformer.optimize(model) + + model = GeminiDDP( + model, + **self.gemini_config, + zero_group=self.zero_group, + extra_dp_group=self.extra_dp_group, + verbose=self.verbose, + ) + + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + optimizer = GeminiOptimizer( + optimizer, + model, + **self.zero_optim_config, + **self.optim_kwargs, + tp_group=self.tp_group, + params_info=params_info, + verbose=self.verbose, + ) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def control_checkpoint_io(self) -> bool: + return True + + def get_checkpoint_io(self) -> CheckpointIO: + return GeminiCheckpointIO() + + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + raise NotImplementedError + + def enable_lora( + self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + ) -> nn.Module: + raise NotImplementedError diff --git a/colossalai/booster/plugin/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/plugin/hybrid_parallel_plugin.py new file mode 100644 index 000000000000..6c1515d38834 --- /dev/null +++ b/colossalai/booster/plugin/plugin/hybrid_parallel_plugin.py @@ -0,0 +1,1466 @@ +import ctypes +import random +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from copy import deepcopy +from functools import partial +from types import MethodType +from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch import Tensor, inf 
+from torch.distributed import ProcessGroup, get_world_size +from torch.nn import Module, SyncBatchNorm +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.accelerator import get_accelerator +from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer +from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO +from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed +from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.quantization import BnbQuantizationConfig, quantize_model +from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer +from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.tensor.d_tensor.api import is_distributed_tensor +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.zero.low_level import LowLevelZeroOptimizer +from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle + +from .pp_plugin_base import PipelinePluginBase + +SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"] + +PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} + + +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + +class HybridParallelModule(ModelWrapper, AMPModelMixin): + def __init__( + self, + module: Module, + precision: str, + shard_config: ShardConfig, + dp_group: ProcessGroup, + tp_group: ProcessGroup, + sp_group: ProcessGroup, + use_ddp: bool, + ddp_config: dict, + custom_policy: Policy, + overlap_allgather: bool = False, + ) -> None: + self.stage_manager = shard_config.pipeline_stage_manager + self.shard_config = shard_config + self.dp_group = dp_group + self.tp_group = tp_group + self.sp_group = sp_group + self.use_ddp = use_ddp + self.require_grad_sync = True + self.overlap_allgather = overlap_allgather + + shardformer = ShardFormer(shard_config) + if custom_policy is not None: + assert isinstance(custom_policy, object) + module, self.shared_params = shardformer.optimize(module, policy=custom_policy) + + # setting process groups for shared parameters + self.shared_param_process_groups = [] + for shared_param in self.shared_params: + if len(shared_param) > 0: + self.shared_param_process_groups.append( + self.stage_manager.init_process_group_by_stages(list(shared_param.keys())) + ) + + # setting mixed_precision + self.mixed_precision = None + if precision == "fp16": + self.mixed_precision = torch.float16 + elif precision == "bf16": + self.mixed_precision = torch.bfloat16 + if self.mixed_precision is not None: + module = module.to(self.mixed_precision) + module = 
module.to(get_accelerator().get_current_device()) + + # setting input type cast when using mixed precision + self.convert_fn = None + if self.mixed_precision is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) + + # setting ddp configs + if use_ddp: + # convert model to sync bn + module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) + # wrap the model with PyTorch DDP + module = DDP(module, process_group=dp_group, **ddp_config) + + super().__init__(module) + if overlap_allgather: + self.op_hook = ZeroOpHook() + for p in module.parameters(): + if p.requires_grad and type(p) is not ColoParameter: + p.__class__ = ColoParameter + p.__init__(p, requires_grad=True) + + def sync_shared_params(self): + for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): + if self.stage_manager.stage in shared_param: + param = shared_param[self.stage_manager.stage] + dist.all_reduce(param.grad, group=group) + dist.barrier() + + @contextmanager + def no_sync(self): + r""" + A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization + when 'no_sync' is active. Alternatively, synchronization will occur in the first forward-backward pass + when exiting the context. + """ + + # Store the current value of 'require_grad_sync' to restore it later. + old_require_grad_sync = self.require_grad_sync + # Disable automatic gradient synchronization. + self.require_grad_sync = False + try: + if self.use_ddp: + # If using data parallel processing (use_ddp), disable synchronization too. + with self.module.no_sync(): + yield + else: + yield + finally: + # Restore the original value of 'require_grad_sync'. + self.require_grad_sync = old_require_grad_sync + + def sync_dp_grads(self): + r""" + Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1. + This function performs an all-reduce operation to combine gradients from different devices in the DP group. + + Args: + None + + Returns: + None + """ + + # Check if the DP group size is 1, meaning no synchronization is needed. + if self.dp_group.size() == 1: + return + + # Iterate through the model's parameters and perform gradient synchronization. + for p in self.module.parameters(): + if p.grad is not None: + # Perform all-reduce to combine gradients from different devices. + dist.all_reduce(p.grad, group=self.dp_group) + # Normalize the gradient by dividing it by the DP group size. + p.grad.div_(self.dp_group.size()) + + def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): + r""" + Synchronize gradients that are partially derived within sequence parallelism + if sequence parallelism is enabled. Gradients can be provided explicitly or extracted + from the module. + + Args: + grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not + provided, gradients will be extracted from the model. + + Returns: + None + """ + + if self.shard_config.enable_sequence_parallelism: + if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: + return + + if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: + # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized + # across the tensor parallelism group. 
+ group = self.tp_group + else: + raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") + + if grads is not None: + # Synchronize provided gradient tensors across the tensor parallelism group. + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads) + else: + # Synchronize gradients from the model across the tensor parallelism group. + SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module) + + def forward(self, *args, **kwargs): + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) + with self._wait_all_gather(): + return super().forward(*args, **kwargs) + + def unwrap(self): + module = super().unwrap() + if isinstance(module, DDP): + module = module.module + return module + + def _force_wait_all_gather(self): + for p in self.module.parameters(): + wait_all_gather_handle(p) + + def _wait_all_gather(self): + return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() + + +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A complete param_group, with params in the form of param_id + # 2. A mapping from param address (obtained using id(param)) to integer param_id + # 3. A mapping from integer param_id to param address. + # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. + # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. + + if optim is None: + return {} + param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} + start_index = 0 + for group in optim.param_groups: + packed_group = {k: v for k, v in group.items() if k != "params"} + packed_group["params"] = [] + + for param_id, param in enumerate(group["params"], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + packed_group["params"].append(param_id) + param_info["param2id"][id(param)] = param_id + param_info["id2param"][param_id] = id(param) + param_info["param2shape"][id(param)] = original_shape + + param_info["param_groups"].append(packed_group) + start_index += len(group["params"]) + + return param_info + + +def reinitialize_optimizer(optim: Optimizer, model: Module): + model_params = set(model.parameters()) + new_param_groups = [] + for group in optim.param_groups: + params = [p for p in group["params"] if p in model_params] + new_param_groups.append({**group, "params": params}) + optim.__setstate__({"param_groups": new_param_groups}) + + +class HybridParallelNaiveOptimizer(OptimizerWrapper): + def __init__( + self, + optim: Optimizer, + model: HybridParallelModule, + use_pipeline: bool, + param_info: OrderedDict, + max_norm: float = 0, + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + ): + self.param_info = param_info + if use_pipeline: + reinitialize_optimizer(optim, model) + self.model = model + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.max_norm = max_norm + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group + self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + super().__init__(optim) + + def backward(self, loss: Tensor, 
*args, **kwargs): + r""" + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs backward pass for gradient computation. If sequence parallelism is enabled + and gradient synchronization is required, it will synchronize gradients that are partially derived + within sequence parallelism across tp parallelism groups. + + Args: + loss (Tensor): The loss tensor to compute gradients with respect to. + *args: Additional positional arguments to be passed to the superclass backward method. + **kwargs: Additional keyword arguments to be passed to the superclass backward method. + + Returns: + None + """ + + # Call the superclass backward method to compute gradients. + super().backward(loss, *args, **kwargs) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation using a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across tp parallelism groups. + + Args: + tensor (Tensor): The input tensor for which gradients are computed. + grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + + # Call the superclass backward method to compute gradients. + super().backward_by_grad(tensor, grad) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def step(self, *args, **kwargs): + r""" + Perform an optimization step. + + Args: + *args: Variable-length positional arguments to be passed to the optimizer's step function. + **kwargs: Keyword arguments to be passed to the optimizer's step function. + """ + + if self.max_norm > 0: + # Compute the total gradient norm. + param_gradient_pairs = [ + (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None + ] + total_norm = self._compute_grad_norm(param_gradient_pairs) + + # Clip the gradients to prevent exploding gradients. + self._clip_grad_norm(total_norm) + + # Perform the optimization step using the underlying optimizer. + self.optim.step(*args, **kwargs) + + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + + if len(param_gradient_pairs) == 0: + return 0.0 + + norm_type = float(norm_type) + + # gradients used for norm calculation. 
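+        # Per-rank values computed from these gradients are reduced across the tp/pp groups below:
+        # MAX all-reduce for the infinity norm, SUM of |g|^p followed by the p-th root otherwise.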
+ gradients = [grad for param, grad in param_gradient_pairs] + + if norm_type == inf: + total_norm = max(grad.data.abs().max() for grad in gradients) + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if self.tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if self.pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + total_norm = total_norm_cuda.item() + else: + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'. + grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} + + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if self.tp_size > 1: + param_for_grad = grad_to_param_mapping[id(grad)] + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= self.tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. + if self.pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_shared_param = shared_param[self.stage_manager.stage] + if grad is stage_shared_param.grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if self.tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if self.pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # compute the total_norm + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + def _clip_grad_norm(self, total_norm: float) -> None: + r""" + Clips the gradients of the model's parameters to prevent exploding gradients. + + Args: + total_norm (float): The computed total gradient norm. 
+ + Returns: + None + """ + clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6)) + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + + for group in self.optim.param_groups: + for p in group["params"]: + if p.grad is None: + continue + p.grad.data.mul_(clip_coef_clamped) + + def update_master_params(self, model: Module): + pass + + def get_working_to_master_map(self): + return None + + def get_master_to_working_map(self): + return None + + +class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): + def __init__( + self, + optim: Optimizer, + model: HybridParallelModule, + use_pipeline: bool, + param_info: OrderedDict, + precision: str = "fp16", + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + ): + self.model = model + self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group + self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 + self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 + if use_pipeline: + reinitialize_optimizer(optim, model) + super().__init__( + optim, + precision=precision, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + max_norm=max_norm, + ) + + def backward(self, loss: Tensor, *args, **kwargs): + r""" + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs backward pass for gradient computation. If sequence parallelism is enabled + and gradient synchronization is required, it will synchronize gradients that are partially derived + within sequence parallelism across tp parallelism groups. + + Args: + loss (Tensor): The loss tensor to compute gradients with respect to. + *args: Additional positional arguments to be passed to the superclass backward method. + **kwargs: Additional keyword arguments to be passed to the superclass backward method. + + Returns: + None + """ + # Call the superclass backward method to compute gradients. + super().backward(loss, *args, **kwargs) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def backward_by_grad(self, tensor: Tensor, grad: Tensor): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation using a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across tp parallelism groups. + + Args: + tensor (Tensor): The input tensor for which gradients are computed. + grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + # Call the superclass backward method to compute gradients. 
+ super().backward_by_grad(tensor, grad) + + if self.model.require_grad_sync: + # If gradient synchronization is required, sync sequence parallelism gradients. + self.model.sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: + r""" + Compute and return the gradient norm for gradient clipping. + + Args: + param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. + norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. + + Returns: + float: The total norm of the given gradients. + """ + if len(param_gradient_pairs) == 0: + return 0.0 + + norm_type = float(norm_type) + + if norm_type == inf: + # The parent class calculates the norm of 'dp' gradients, + # so we need to calculate the norm of 'tp' and 'pp' gradients. + total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) + + total_norm_cuda = torch.tensor( + [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + + if self.tp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) + if self.pp_size > 1: + dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) + + total_norm = total_norm_cuda.item() + + else: + # gradients used for norm calculation. + gradients = [grad for param, grad in param_gradient_pairs] + # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism. + grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} + + total_norm_exponentiated = 0.0 + for grad in gradients: + grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type + + # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, + # it indicates that the parameter is not distributed across devices of the 'tp_group'. + # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. + # However, we still perform the 'all_reduce' operation for the sake of good coding practices. + # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' + if self.tp_size > 1: + param_for_grad = grad_to_param_mapping[id(grad)] + if not is_distributed_tensor(param_for_grad): + grad_norm_exponentiated /= self.tp_size + + # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, + # it means that this parameter is used in two different pipeline stages. + # To avoid redundant norm calculations, we divide the exponent of this norm by + # the number of shared stages. 
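+                # After the SUM all-reduce over the pp group below, each shared parameter therefore
+                # contributes its |grad|^p to the total norm exactly once.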
+ if self.pp_size > 1: + for shared_param in self.shared_params: + if self.stage_manager.stage in shared_param: + stage_working_shared_param = shared_param[self.stage_manager.stage] + stage_master_shared_param = self.working_to_master_map[stage_working_shared_param] + if grad is stage_master_shared_param.grad: + grad_norm_exponentiated /= len(shared_param) + + total_norm_exponentiated += grad_norm_exponentiated + + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 + ) + if self.tp_size > 1: + # compute norm in tp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) + if self.pp_size > 1: + # compute norm in pp process group + dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) + + # compute the total_norm + total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) + + return total_norm + + +class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): + def __init__( + self, + optimizer: Optimizer, + model: HybridParallelModule, + use_pipeline: bool, + param_info: OrderedDict, + pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + pp_process_group: Optional[ProcessGroup] = None, # if using pp + forced_dtype: Optional[torch.dtype] = None, + overlap_allgather: bool = False, + ): + self.model = model + self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.tp_pg = tp_process_group + self.pp_pg = pp_process_group + if use_pipeline: + reinitialize_optimizer(optimizer, model) + super().__init__( + optimizer=optimizer, + initial_scale=initial_scale, + min_scale=min_scale, + pg_to_param_list=pg_to_param_list, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + clip_grad_norm=clip_grad_norm, + verbose=verbose, + reduce_bucket_size=reduce_bucket_size, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=partition_grad, + cpu_offload=cpu_offload, + dp_process_group=dp_process_group, + forced_dtype=forced_dtype, + overlap_allgather=overlap_allgather, + ) + + def sync_dp_grads(self): + r""" + Synchronize gradients in the data parallelism dimension. + + This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients + in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions, + namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization + and readability. + + Args: + None + + Returns: + None + """ + # Call the superclass `_sync_grad` method to synchronize gradients. 
+ super()._sync_grad() + + def _sync_sp_grads(self): + r""" + Synchronize gradients that are partially derived within sequence parallelism. + + This method is responsible for synchronizing partially derived gradients across tp parallelism groups. + It identifies gradients that ara partially derived or not and synchronizes them. + If synchronization is required and gradients are found to be synchronized, + it performs the synchronization. + + Args: + None + + Returns: + None + """ + + def _get_all_working_grads() -> List[Tensor]: + """Retrieve all working gradients from different parameter groups.""" + all_working_grads = [] + for group_id in range(self.num_param_groups): + working_grads = self.get_working_grads_by_group_id(group_id) + all_working_grads.extend(working_grads) + return all_working_grads + + def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: + """Identify gradients to be synchronized in the sequence parallelism.""" + grads_to_sync = [] + for grad in all_working_grads: + param_id_for_grad = self.get_param_id_for_grad(grad) + param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value + if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad): + grads_to_sync.append(grad) + + if len(grads_to_sync) > 0: + return grads_to_sync + else: + return None + + # Get all working gradients and gradients to be synchronized. + all_working_grads = _get_all_working_grads() + grads_to_sync = _get_grads_to_sync(all_working_grads) + if self.require_grad_sync and grads_to_sync is not None: + # Synchronize sequence parallelism gradients if required. + SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) + else: + return + + def backward(self, loss, retain_graph=False): + """ + Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. + + This method performs the backward pass for gradient computation based on a given loss tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across TP parallelism groups. + + Args: + loss: The loss tensor to compute gradients with respect to. + retain_graph (bool): Whether to retain the computation graph. + + Returns: + None + """ + # Call the superclass backward method to compute gradients. + super().backward(loss, retain_graph) + + if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: + # If gradient synchronization is required, sync sequence parallelism gradients. + self._sync_sp_grads() + else: + # If gradient synchronization is is not required, return. + return + + def backward_by_grad(self, tensor, grad): + """ + Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. + + This method performs a backward pass for gradient computation based on a precomputed gradient tensor. + If sequence parallelism is enabled and gradient synchronization is required, it will synchronize + gradients that are partially derived within sequence parallelism across TP parallelism groups. + + Args: + tensor: The input tensor for which gradients are computed. + grad: The precomputed gradient tensor to compute gradients with respect to the input tensor. + + Returns: + None + """ + # Call the superclass backward_by_grad method to compute gradients. 
+        super().backward_by_grad(tensor, grad)
+
+        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+            # If gradient synchronization is required, sync sequence parallelism gradients.
+            self._sync_sp_grads()
+        else:
+            # If gradient synchronization is not required, return.
+            return
+
+    def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            gradients (List[Tensor]): A list of tensors containing gradients.
+            norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2.
+
+        Returns:
+            float: The computed gradient norm.
+        """
+
+        # Check if the list of gradients is empty
+        if len(gradients) == 0:
+            return 0.0
+
+        dp_size = get_world_size(dp_pg) if dp_pg is not None else 1
+        tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+        pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        norm_type = float(norm_type)
+
+        if norm_type == inf:
+            # The parent class calculates the norm of 'dp' gradients,
+            # so we only need to calculate the norm of the 'tp' and 'pp' gradients.
+            total_norm = super()._compute_grad_norm(gradients, norm_type)
+
+            total_norm_cuda = torch.tensor(
+                [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
+
+            if tp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
+            if pp_size > 1:
+                dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
+
+            total_norm = total_norm_cuda.item()
+        else:
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type
+
+                # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor,
+                # it indicates that the parameter is not distributed across devices of the 'tp_group'.
+                # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
+                # However, we still perform the 'all_reduce' operation for the sake of good coding practices.
+                # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
+                if tp_size > 1:
+                    param_id_for_grad = self.get_param_id_for_grad(grad)
+                    param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
+
+                    if not is_distributed_tensor(param_for_grad):
+                        grad_norm_exponentiated /= tp_size
+
+                # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
+                # it means that this parameter is used in two different pipeline stages.
+                # To avoid redundant norm calculations, we divide the exponent of this norm by
+                # the number of shared stages.
+                if pp_size > 1:
+                    for shared_param in self.shared_params:
+                        if self.stage_manager.stage in shared_param:
+                            stage_shared_param = shared_param[self.stage_manager.stage]
+                            working_grad = self.get_working_grad_by_param_id(id(stage_shared_param))
+                            if grad is working_grad:
+                                grad_norm_exponentiated /= len(shared_param)
+
+                total_norm_exponentiated += grad_norm_exponentiated
+
+            total_norm_exponentiated_cuda = torch.tensor(
+                [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
+            )
+            if dp_size > 1:
+                # compute norm in dp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg)
+            if tp_size > 1:
+                # compute norm in tp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
+            if pp_size > 1:
+                # compute norm in pp process group
+                dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
+
+            # Compute the 'total_norm' from 'total_norm_exponentiated'
+            total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)
+
+        return total_norm
+
+
+class HybridParallelPlugin(PipelinePluginBase):
+    """
+    Plugin for Hybrid Parallel Training.
+    Tensor parallelism, pipeline parallelism and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
+    The size of tp and pp should be passed in by the user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
+
+    ```python
+    from colossalai.booster import Booster
+    from colossalai.booster.plugin import HybridParallelPlugin
+
+    model, train_dataset, optimizer, criterion = ...
+    plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
+
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+    booster = Booster(plugin=plugin)
+    model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+    ```
+
+    Args:
+        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
+        pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+        sp_size (int): The size of sequence parallelism.
+        precision (str, optional): Specifies the precision of parameters during training.
+            Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise the model is trained with 'fp32'.
+            Defaults to 'fp16'.
+        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
+            When set to 0, ZeRO will not be used. Defaults to 0.
+        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+            Currently all the optimization methods include fused normalization, flash attention and JIT.
+            Defaults to False.
+        enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
+        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all", "ring_attn"]. Defaults to "all_to_all" when sequence parallelism is enabled.
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Defaults to True.
+        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
+        initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
+        min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
+        growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
+        backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
+        growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
+        hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
+        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
+        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
+        broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
+        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+        zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
+        cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
+        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
+        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+        custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+        pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
+        num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
+        gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
+        enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
+        make_vocab_size_divisible_by (int, optional): It's used when padding the vocabulary size to make it choose a faster kernel. Defaults to 64.
+        overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True.
+        inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
+            It's advisable not to tune this (especially in single-node settings) and let it be heuristically set based on topology by default.
+ + """ + + def __init__( + self, + tp_size: int, + pp_size: int, + sp_size: int = None, + precision: str = "fp16", + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + sequence_parallelism_mode: str = None, + enable_sequence_overlap: bool = False, + parallel_output: bool = True, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + custom_policy: Policy = None, + pp_style: str = "1f1b", + num_model_chunks: int = 1, + num_layers_per_stage: Optional[List[int]] = None, + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, + enable_metadata_cache: bool = True, + make_vocab_size_divisible_by: int = 64, + dp_outside: bool = True, + overlap_p2p: bool = True, + overlap_allgather: bool = False, + inner_ring_size: int = None, + ) -> None: + super().__init__() + self.logger = get_dist_logger() + + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + + if enable_sequence_parallelism: + self.sequence_parallelism_mode = ( + sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" + ) + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + tp_size > 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" + if sp_size != 1: + self.logger.warning( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + ) + self.sp_size = 1 + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: + self.sp_size = 1 if sp_size is None else sp_size + self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) + if self.sequence_parallelism_mode == "ring_attn": + enable_flash_attention = True + else: + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + assert ( + sp_size == 1 or sp_size is None + ), f"You should not set sp_size when sequence parallelism is not enabled." 
+ self.sp_size = 1 + + self.tp_size = tp_size + self.pp_size = pp_size + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism + if dp_outside: + self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + if sequence_parallelism_mode == "ring_attn": + # Swap tp and sp since 2D Ring has better inter-node latency + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) + else: + self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 + if sequence_parallelism_mode == "ring_attn": + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size) + self.sp_axis = 2 + self.tp_axis = 3 + else: + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) + + self.stage_manager = None + self.schedule = None + self.custom_policy = custom_policy + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" + assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert ( + num_microbatches is not None or microbatch_size is not None + ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" + assert ( + self.zero_stage <= 1 + ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + self.stage_manager = PipelineStageManager( + self.pg_mesh, + pipeline_axis=self.pp_axis, + enable_interleave=pp_style == "interleaved", + num_model_chunks=num_model_chunks, + num_layers_per_stage=num_layers_per_stage, + ) + + if pp_style == "interleaved": + assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" + self.schedule = InterleavedSchedule( + stage_manager=self.stage_manager, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + overlap_p2p=overlap_p2p, + ) + elif pp_style == "1f1b": + self.schedule = OneForwardOneBackwardSchedule( + stage_manager=self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + ) + else: + raise NotImplementedError() + if sequence_parallelism_mode == "ring_attn": + if not parallel_output: + self.logger.warning( + "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet." 
+ ) + parallel_output = True + + self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) + self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: + self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + else: + self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) + + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + sequence_parallel_process_group=self.sp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + sequence_parallelism_mode=sequence_parallelism_mode, + enable_sequence_overlap=enable_sequence_overlap, + parallel_output=parallel_output, + make_vocab_size_divisible_by=make_vocab_size_divisible_by, + gradient_checkpoint_config=gradient_checkpoint_config, + inner_ring_size=inner_ring_size, + ) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + ) + + self.ddp_config = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) + + self.zero_config = dict( + reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2), + forced_dtype=PRECISION_TORCH_TYPE[precision], + overlap_allgather=overlap_allgather, + ) + + self.max_norm = max_norm + + def __del__(self): + """Destroy the process groups in ProcessGroupMesh""" + self.pg_mesh.destroy_mesh_process_groups() + + @property + def enable_pipeline_parallelism(self) -> bool: + return self.pp_size > 1 + + def supported_devices(self) -> List[str]: + return ["cuda", "npu"] + + def supported_precisions(self) -> List[str]: + return ["fp16", "bf16", "fp32"] + + def control_device(self) -> bool: + return True + + def control_precision(self) -> bool: + return True + + def support_no_sync(self) -> bool: + return True + + def support_lora(self) -> bool: + return True + + def control_checkpoint_io(self) -> bool: + return True + + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) + + # TODO: Support Galore + ZeRO + zero_stage = self.zero_stage + zero_config = deepcopy(self.zero_config) + + # Replace with distributed implementation if exists + optimizer = cast_to_distributed(optimizer) + + if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: + self.logger.warning( + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO." 
+ ) + zero_config["partition_grad"] = False + zero_stage = 0 + + if not isinstance(model, ModelWrapper): + # Shouldn't use pp (frequent grad accumulation) with torch ddp + use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( + self.dp_size == 1 and self.pp_size == 1 + ) + + # Apply Hybrid ZeRO across DP * SP ranks + if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): + dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) + self.dp_size = get_world_size(dp_group) + else: + dp_group = self.dp_group + model = HybridParallelModule( + model, + precision=self.precision, + shard_config=self.shard_config, + dp_group=dp_group, + tp_group=self.tp_group, + sp_group=self.sp_group, + use_ddp=use_ddp, + ddp_config=self.ddp_config, + custom_policy=self.custom_policy, + overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), + ) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if zero_stage == 0: + is_zero = False + if self.precision in ["fp16", "bf16"]: + optimizer = HybridParallelAMPOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, + **self.amp_config, + ) + else: + optimizer = HybridParallelNaiveOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, + ) + else: + is_zero = self.dp_size > 1 + if self.dp_size == 1: + self.logger.warning( + "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " + "If you do not intend to use cpu_offload, please consider set zero_stage=0." + ) + + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." + optimizer = HybridParallelZeroOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=dp_group, + tp_process_group=self.tp_group, + pp_process_group=self.pp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **zero_config, + **self.amp_config, + ) + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) + + # Setup optimizers that require global states + optim = optimizer.optim + if isinstance(optim, DistributedOptim): + shard_to_param = optimizer.get_master_to_working_map() if is_zero else {} + padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int) + optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def execute_pipeline( + self, + data_iter: Iterator, + model: HybridParallelModule, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[ + Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer] + ] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> dict: + assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" + + if return_outputs: + self.logger.warning("return_outputs may lead to significant extra memory consumption.") + + # Create a context for gradient synchronization based on the optimizer type. 
+ # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). + # This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once), + # so we disable it, performing manual reduction instead. + ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + + with ctx, model._wait_all_gather(): + outputs = self.schedule.forward_backward_step( + model, data_iter, criterion, optimizer, return_loss, return_outputs + ) + + # run with gradients accumulation + if model.require_grad_sync == False or ( + isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False + ): + return outputs + + # Synchronize the grads of shared parameters of the model. + model.sync_shared_params() + # Synchronize sequence parallelism gradients of the model. + model.sync_sp_grads() + + # Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so. + # Otherwise, synchronize data parallelism gradients of the model. + # This is because these are two different forms of data parallelism. + if isinstance(optimizer, HybridParallelZeroOptimizer): + optimizer.sync_dp_grads() + else: + model.sync_dp_grads() + + return outputs + + def prepare_dataloader( + self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + distributed_sampler_cls=None, + **kwargs, + ): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns:` + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
+ """ + _kwargs = kwargs.copy() + distributed_sampler_cls = distributed_sampler_cls or DistributedSampler + sampler = distributed_sampler_cls( + dataset, + num_replicas=self.dp_group.size(), + rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()), + shuffle=shuffle, + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + + def get_checkpoint_io(self) -> CheckpointIO: + return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + + def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert ( + self.zero_stage != 2 + ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed." + return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() + + def enable_lora( + self, + model: Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, + ) -> Module: + from peft import PeftModel, get_peft_model + + assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." + assert self.pp_size == 1 and self.tp_size == 1 + self.lora_enabled = True + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr") + + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + + if pretrained_dir is None: + peft_model = get_peft_model(model, lora_config) + else: + peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) + return peft_model diff --git a/colossalai/booster/plugin/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/plugin/low_level_zero_plugin.py new file mode 100644 index 000000000000..6c36bad3c214 --- /dev/null +++ b/colossalai/booster/plugin/plugin/low_level_zero_plugin.py @@ -0,0 +1,521 @@ +import enum +import os +from contextlib import nullcontext +from functools import partial +from pathlib import Path +from types import MethodType +from typing import Callable, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.distributed_c10d import _get_default_group +from torch.nn import Parameter +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader + +from colossalai.accelerator import get_accelerator +from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO +from colossalai.checkpoint_io.utils import ( + get_optimizer_base_filenames, + get_shard_filename, + load_param_groups_into_optimizer, + load_shard_state_dict, + load_states_into_optimizer, + save_param_groups, + save_state_dict, + sharded_optimizer_loading_epilogue, +) +from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed +from colossalai.quantization 
import BnbQuantizationConfig, quantize_model +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.zero import LowLevelZeroOptimizer +from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle + +from .dp_plugin_base import DPPluginBase +from .torch_ddp_plugin import TorchDDPCheckpointIO + +__all__ = ["LowLevelZeroPlugin"] + + +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): + if isinstance(x, torch.Tensor) and torch.is_floating_point(x): + return x.to(dtype) + return x + + +SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"] + + +class OptimizerParamCheckState(enum.Enum): + ORIGIN_PARAM_FINDED = 0 + ORIGIN_PARAM_NOT_FIND = -1 + LORA_PARM_EXISTED = -2 + + +class LowLevelZeroModel(ModelWrapper, AMPModelMixin): + def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: + super().__init__(module) + self.dtype = None + if precision == "fp16": + self.dtype = torch.float16 + elif precision == "bf16": + self.dtype = torch.bfloat16 + if self.dtype is not None: + module = module.to(self.dtype) + module = module.to(get_accelerator().get_current_device()) + self.module = module + self.convert_fn = None + if self.dtype is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) + self.overlap_allgather = overlap_allgather + if overlap_allgather: + self.op_hook = ZeroOpHook() + for p in module.parameters(): + if p.requires_grad and type(p) is not ColoParameter: + p.__class__ = ColoParameter + p.__init__(p, requires_grad=True) + + def forward(self, *args, **kwargs): + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) + ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() + with ctx: + return super().forward(*args, **kwargs) + + def _force_wait_all_gather(self): + for p in self.module.parameters(): + wait_all_gather_handle(p) + + +class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): + """Save optimizer to checkpoint but only on master process. + + Args: + optimizer (OptimizerWrapper): Optimizer to save state_dict + checkpoint (str): Path to save checkpoint + gather_dtensor (bool): Whether to gather_dtensor, not used + """ + assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" + # the `state_dict` in LowLevelZeroOptimizer has communication + # if only the master rank collect state_dict and save, + # the communication on each rank would not match + state_dict = optimizer.state_dict() + if self.coordinator.is_master(): + save_state_dict(state_dict, checkpoint, use_safetensors=False) + self.logger = get_dist_logger() + + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = False, + prefix: str = None, + size_per_shard: int = 1024, + ): + """ + Save sharded Zero-optimizer checkpoint under the given checkpointing path. 
+        The following files will be created under the path:
+        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+        - A group file (pytorch_optim_group.bin) recording information of param_groups
+        - Multiple files (pytorch_optim-000XX.bin) that store state tensors of the optimizer in a sharded way
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+            checkpoint (str): Path to save optimizer state_dict
+            gather_dtensor (bool): Whether to gather_dtensor, not used
+            prefix (str): Prefix of the files to save
+            size_per_shard (int): Max file size of each file that stores state tensors
+        """
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
+        if os.path.isfile(checkpoint):
+            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # the optimizer's state_dict here provides only 'param_groups'
+        state_dict = optimizer.optim.state_dict()
+        # state sharding is handled by the low-level zero optimizer
+        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+
+        # Preparing file paths and index file.
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+        index_file.append_meta_data("param_groups", param_group_file)
+
+        # Store the information of param groups to param_group_file.
+        if self.coordinator.is_master():
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            save_param_groups(state_dict, group_file_path)
+
+        # Save shards of optimizer states.
+        total_size = 0
+        for idx, shard_pair in enumerate(sharded_state):
+            shard, current_size = shard_pair
+            shard_file = get_shard_filename(states_name, idx)
+            total_size = total_size + current_size
+            for param_id in shard.keys():
+                index_file.append_weight_map(str(param_id), shard_file)
+
+            checkpoint_file_path = os.path.join(checkpoint, shard_file)
+            if self.coordinator.is_master():
+                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+
+        # Wrap up index file.
+        index_file.append_meta_data("total_size", total_size)
+        if self.coordinator.is_master():
+            index_file.write_index_file(save_index_file)
+            self.logger.info(
+                f"The optimizer is going to be split to checkpoint shards. "
+                f"You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
+            )
+
+    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
+        """Load sharded optimizer with the given path to index file.
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to load state_dict
+            index_file_path (str): Path to the index file
+            prefix (str): Not used.
+        """
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before loading!"
+        optimizer = optimizer.unwrap()
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(
+                f"Invalid index file path {index_file_path} for an optimizer. \
+                   Lacking param group file under current directory."
+ ) + id_map = load_param_groups_into_optimizer(optimizer, param_group_path) + + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + + for shard_file in checkpoint_files: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + # shard state dict + for param_idx, state in state_dict.items(): + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != "step": + padding_size = ( + self.coordinator.world_size - v.numel() % self.coordinator.world_size + ) % self.coordinator.world_size + with torch.no_grad(): + v = v.flatten() + if padding_size > 0: + v = torch.nn.functional.pad(v, [0, padding_size]) + v_list = v.split(v.numel() // self.coordinator.world_size) + state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone() + load_states_into_optimizer(optimizer, state_dict, id_map) + sharded_optimizer_loading_epilogue(optimizer) + + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() + super().load_unsharded_model(model, checkpoint, strict) + model.update_master_params() + + def load_sharded_model( + self, + model: ModelWrapper, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() + super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) + model.update_master_params() + + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() + return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): + assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" + model._force_wait_all_gather() + return super().save_sharded_model( + model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + ) + + def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): + if os.path.isfile(checkpoint): + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + from peft import PeftModel + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + model._force_wait_all_gather() + peft_model = model.unwrap() + assert isinstance( + peft_model, PeftModel + ), "The model doesn't have lora adapters, please enable lora before saving." + return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors) + + +class LowLevelZeroPlugin(DPPluginBase): + """ + Plugin for low level zero. + + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import LowLevelZeroPlugin + + model, train_dataset, optimizer, criterion = ... 
+    plugin = LowLevelZeroPlugin()
+
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+    booster = Booster(plugin=plugin)
+    model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+    ```
+
+    Args:
+        stage (int, optional): ZeRO stage. Defaults to 1.
+        precision (str, optional): Precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
+        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
+        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
+        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
+        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
+        growth_interval (int, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
+        hysteresis (int, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
+        max_scale (float, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
+        max_norm (float, optional): max_norm used for `clip_grad_norm`. Note that you should not call `clip_grad_norm`
+            yourself when using ZeRO DDP; the ZeRO optimizer takes care of gradient clipping.
+        norm_type (float, optional): norm_type used for `clip_grad_norm`. Only 2.0 is supported.
+        reduce_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements. Defaults to 12.
+        communication_dtype (torch.dtype, optional): Communication dtype. If not specified, the dtype of params will be used. Defaults to None.
+        overlap_communication (bool, optional): Whether to overlap communication and computation. Defaults to True.
+        cpu_offload (bool, optional): Whether to offload grads, master weights and optimizer states to CPU. Defaults to False.
+        verbose (bool, optional): Verbose mode. Debug info including grad overflow will be printed. Defaults to False.
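+        overlap_allgather (bool, optional): Whether to overlap the all-gather of ZeRO-sharded parameters with computation. Defaults to False.
+        master_weights (bool, optional): Whether to keep fp32 master weights for mixed-precision training. Defaults to True.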
+ """ + + def __init__( + self, + stage: int = 1, + precision: str = "fp16", + initial_scale: float = 2**32, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0.0, + norm_type: float = 2.0, + reduce_bucket_size_in_m: int = 12, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + overlap_allgather: bool = False, + cpu_offload: bool = False, + master_weights: bool = True, + verbose: bool = False, + ) -> None: + super().__init__() + assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" + assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training" + assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now" + self.stage = stage + self.precision = precision + self.zero_optim_kwargs = dict( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + clip_grad_norm=max_norm, + reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=(stage == 2), + cpu_offload=cpu_offload, + master_weights=master_weights, + overlap_allgather=overlap_allgather, + ) + self.lora_enabled = False + self.verbose = verbose + self.logger = get_dist_logger() + # set class name with stage, for better error message + setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") + + def support_no_sync(self) -> bool: + return self.stage == 1 + + def support_lora(self) -> bool: + return False + + def control_precision(self) -> bool: + return True + + def supported_precisions(self) -> List[str]: + return SUPPORTED_PRECISION + + def control_device(self) -> bool: + return True + + def supported_devices(self) -> List[str]: + return ["cuda", "npu"] + + def support_lora(self) -> bool: + return True + + def enable_lora( + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, + ) -> nn.Module: + from peft import PeftModel, get_peft_model + + assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model." + self.lora_enabled = True + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr") + + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + + if pretrained_dir is None: + peft_model = get_peft_model(model, lora_config) + else: + peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) + return peft_model + + def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter): + origin_param_id = id(origin_param) + for group_id, param_group in enumerate(optimizer.param_groups): + for p in param_group["params"]: + if id(p) == origin_param_id: + return group_id + return -1 + + def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter): + origin_param_id = id(origin_param) + lora_param_id = id(lora_param) + target_group_id = None + for group_id, param_group in enumerate(optimizer.param_groups): + for p in param_group["params"]: + if id(p) == lora_param_id: + # check if the lora parameter exists. 
+ return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED + if id(p) == origin_param_id: + target_group_id = group_id + if target_group_id is not None: + return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED + else: + return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND + + def add_lora_params_to_optimizer(self, model, optimizer): + """add lora parameters to optimizer""" + name2param = {} + for name, param in model.named_parameters(): + name2param[name] = param + + for name, param in name2param.items(): + if "lora_A" in name or "lora_B" in name: + origin_key = name.replace("lora_A.", "") + origin_key = origin_key.replace("lora_B.", "") + origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer") + origin_param = name2param[origin_key] + group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) + if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: + self.logger.warning( + f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." + ) + elif ( + check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED + and group_id is not None + and group_id >= 0 + ): + optimizer.param_groups[group_id]["params"].append(param) + + def configure( + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + if self.lora_enabled: + from peft import PeftModel + + assert isinstance( + model, PeftModel + ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True" + if optimizer is not None: + self.add_lora_params_to_optimizer(model, optimizer) + + if not isinstance(model, ModelWrapper): + model = LowLevelZeroModel( + model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] + ) + + # TODO: Support Galore + ZeRO + zero_stage = self.stage + zero_optim_kwargs = {**self.zero_optim_kwargs} + dp_size = dist.get_world_size() + + # Replace with the distributed implementation if exists + optimizer = cast_to_distributed(optimizer) + + if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0: + self.logger.warning( + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO." 
+ ) + zero_optim_kwargs["partition_grad"] = False + zero_stage = 0 + + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer( + optimizer, **zero_optim_kwargs, verbose=self.verbose + ) + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) + + # Setup optimizers that require global states + optim = optimizer.optim + is_zero = dp_size > 1 and zero_stage > 0 + dp_group = _get_default_group() # Use the whole world + if isinstance(optim, DistributedOptim): + shard_to_param = optimizer.get_master_to_working_map() + padding_map = optimizer.get_param_padding_map() + optim.setup_distributed(None, dp_group, shard_to_param, padding_map, is_zero) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def control_checkpoint_io(self) -> bool: + return True + + def get_checkpoint_io(self) -> CheckpointIO: + return LowLevelZeroCheckpointIO() + + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert isinstance(optimizer, LowLevelZeroOptimizer) + return optimizer.no_sync() diff --git a/colossalai/booster/plugin/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/plugin/moe_hybrid_parallel_plugin.py new file mode 100644 index 000000000000..874028f09b86 --- /dev/null +++ b/colossalai/booster/plugin/plugin/moe_hybrid_parallel_plugin.py @@ -0,0 +1,490 @@ +from collections import defaultdict +from types import MethodType +from typing import Callable, List, Optional, OrderedDict, Tuple + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import Module +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader + +from colossalai.booster.plugin.hybrid_parallel_plugin import ( + PRECISION_TORCH_TYPE, + SUPPORT_SP_MODE, + HybridParallelAMPOptimizer, + HybridParallelModule, + HybridParallelNaiveOptimizer, + HybridParallelPlugin, + HybridParallelZeroOptimizer, + get_param_info, + reinitialize_optimizer, +) +from colossalai.checkpoint_io import MoECheckpointIO +from colossalai.cluster.process_group_mesh import ProcessGroupMesh +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.interface.optimizer import DistributedOptim +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import cast_to_distributed +from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule +from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.policies.base_policy import Policy +from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig +from colossalai.shardformer.shard.shard_config import ShardConfig +from colossalai.tensor.moe_tensor.api import is_moe_tensor + + +class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer): + def __init__( + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + dp_process_group: Optional[ProcessGroup], # the dp pg for comm + tp_process_group: Optional[ProcessGroup], # if using tp + pp_process_group: Optional[ProcessGroup], # if using pp + moe_dp_group: ProcessGroup, # moe dp pg for comm + param_info: OrderedDict, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2.0, + backoff_factor: float = 
0.5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = False, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + forced_dtype: Optional[torch.dtype] = None, + overlap_allgather: bool = False, + ): + pg_param_list = { + dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())), + moe_dp_group: list(filter(is_moe_tensor, model.parameters())), + } + + if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0: + raise ValueError("No parameters found in dp_process_group or moe_dp_group") + + super().__init__( + model=model, + optimizer=optimizer, + use_pipeline=use_pipeline, + param_info=param_info, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + clip_grad_norm=clip_grad_norm, + verbose=verbose, + reduce_bucket_size=reduce_bucket_size, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + partition_grad=partition_grad, + cpu_offload=cpu_offload, + tp_process_group=tp_process_group, + pp_process_group=pp_process_group, + forced_dtype=forced_dtype, + pg_to_param_list=pg_param_list, + overlap_allgather=overlap_allgather, + ) + + +class MoeHybridParallelPlugin(HybridParallelPlugin): + """ + Plugin for MoE Hybrid Parallel Training, which is similar to HybridParallelPlugin + Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import MoeHybridParallelPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2, ep_size=2) + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + ``` + + Args: + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + ep_size (int): The size of expert parallelism + sp_size (int): The size of sequence parallelism. + precision (str, optional): Specifies the precision of parameters during training. + Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. + Defaults to 'fp16'. + zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. + When set to 0, ZeRO will not be used. Defaults to 0. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. 
+        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
+        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        sequence_parallelism_mode (str): The sequence parallelism mode. Can only be chosen from ["split_gather", "ring", "all_to_all"]. Defaults to "all_to_all".
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Defaults to True.
+        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+            Either ``num_microbatches`` or ``microbatch_size`` should be provided when using pipeline parallelism.
+            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
+        initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
+        min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
+        growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
+        backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
+        growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
+        hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
+        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
+        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
+        broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
+        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+        zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
+        cpu_offload (bool, optional): Whether to enable CPU offload when using ZeRO. Defaults to False.
+        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of params will be used. Defaults to None.
+        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to False.
+        custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
+        pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
+        num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
+        gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
+        enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
+ make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. + overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism + """ + + def __init__( + self, + tp_size: int, + pp_size: int, + ep_size: int, + sp_size: int = None, + precision: str = "fp16", + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + sequence_parallelism_mode: str = None, + enable_sequence_overlap: bool = False, + parallel_output: bool = True, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = False, + custom_policy: Policy = None, + pp_style: str = "1f1b", + num_model_chunks: int = 1, + num_layers_per_stage: Optional[List[int]] = None, + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, + enable_metadata_cache: bool = True, + make_vocab_size_divisible_by: int = 64, + moe_dp_outside: bool = True, + overlap_p2p: bool = True, + overlap_allgather: bool = False, + ) -> None: + self.logger = get_dist_logger() + if overlap_communication or zero_stage == 2: + overlap_communication = False + zero_stage = 1 + self.logger.warning( + f"overlap_communication and zero_stage are set to False and 1 because " + f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. " + ) + + assert ( + dist.get_world_size() % (tp_size * pp_size) == 0 + ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" + if enable_sequence_parallelism: + self.sequence_parallelism_mode = ( + sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" + ) + assert ( + self.sequence_parallelism_mode in SUPPORT_SP_MODE + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" + if self.sequence_parallelism_mode in ["split_gather", "ring"]: + assert ( + tp_size > 1 + ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" + if sp_size != 1: + self.logger.warning( + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + ) + self.sp_size = 1 + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + elif self.sequence_parallelism_mode in ["all_to_all"]: + self.sp_size = 1 if sp_size is None else sp_size + self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) + else: + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + assert ( + sp_size == 1 or sp_size is None + ), f"You should not set sp_size when sequence parallelism is not enabled." 
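+            # Sequence parallelism is disabled here, so the sp mesh dimension degenerates to 1.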
+ self.sp_size = 1 + + assert self.dp_size % ep_size == 0, f"dp_size should be divisible by ep_size, {self.dp_size=} {ep_size=}" + self.moe_dp_size = self.dp_size // ep_size + self.ep_size = ep_size + self.tp_size = tp_size + self.pp_size = pp_size + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism + if moe_dp_outside: + self.moe_dp_axis, self.pp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4 + self.pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.pp_size, self.ep_size, self.tp_size, self.sp_size) + else: + self.pp_axis, self.moe_dp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4 + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size) + + self.stage_manager = None + self.schedule = None + self.custom_policy = custom_policy + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" + assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" + assert ( + num_microbatches is not None or microbatch_size is not None + ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" + assert ( + self.zero_stage <= 1 + ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" + self.stage_manager = PipelineStageManager( + self.pg_mesh, + pipeline_axis=self.pp_axis, + enable_interleave=pp_style == "interleaved", + num_model_chunks=num_model_chunks, + num_layers_per_stage=num_layers_per_stage, + ) + + if pp_style == "interleaved": + assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" + self.schedule = InterleavedSchedule( + stage_manager=self.stage_manager, + num_model_chunks=num_model_chunks, + num_microbatch=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + overlap_p2p=overlap_p2p, + ) + elif pp_style == "1f1b": + self.schedule = OneForwardOneBackwardSchedule( + stage_manager=self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size, + enable_metadata_cache=enable_metadata_cache, + ) + else: + raise NotImplementedError() + + self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + self.dp_group = self.pg_mesh.get_group_along_axis([self.moe_dp_axis, self.ep_axis]) + self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) + self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.moe_dp_axis) + self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis) + if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: + self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) + else: + self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) + + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + sequence_parallel_process_group=self.sp_group, + ep_group=self.ep_group, + moe_dp_group=self.moe_dp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + 
enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + sequence_parallelism_mode=sequence_parallelism_mode, + enable_sequence_overlap=enable_sequence_overlap, + parallel_output=parallel_output, + make_vocab_size_divisible_by=make_vocab_size_divisible_by, + gradient_checkpoint_config=gradient_checkpoint_config, + ) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + ) + + self.ddp_config = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) + + self.zero_config = dict( + reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2), + forced_dtype=PRECISION_TORCH_TYPE[precision], + overlap_allgather=overlap_allgather, + ) + + self.max_norm = max_norm + + def get_checkpoint_io(self) -> MoECheckpointIO: + return MoECheckpointIO( + self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage + ) + + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) + + # TODO: Support Galore + ZeRO + # Replace with distributed implementation if exists + optimizer = cast_to_distributed(optimizer) + + if not isinstance(model, ModelWrapper): + use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( + self.dp_size == 1 + and self.pp_size == 1 + and self.enable_sequence_parallelism + and self.sequence_parallelism_mode == "all_to_all" + ) + if use_ddp: + self.logger.warning( + f"Will have to check all params are used in pytorch DDP since not all experts are always activated" + ) + self.ddp_config["find_unused_parameters"] = True + + if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group): + raise ValueError( + f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. 
set ep_size = 1) or set zero_stage > 0" + ) + + # sync gradients across DP * SP ranks + if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": + dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis]) + else: + dp_group = self.dp_group + + model = HybridParallelModule( + module=model, + precision=self.precision, + shard_config=self.shard_config, + dp_group=dp_group, + tp_group=self.tp_group, + sp_group=self.sp_group, + use_ddp=use_ddp, + ddp_config=self.ddp_config, + custom_policy=self.custom_policy, + ) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if self.ep_size > 1: + # if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups + # but the optimizer is not aware of ep, so we need to update the optimizer + reinitialize_optimizer(optimizer, model) + + if self.zero_stage == 0: + is_zero = False + if self.precision in ["fp16", "bf16"]: + optimizer = HybridParallelAMPOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config, + ) + else: + optimizer = HybridParallelNaiveOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + max_norm=self.max_norm, + pp_process_group=self.pp_group, + tp_process_group=self.tp_group, + ) + else: + if self.dp_size <= 1: + self.logger.warning( + "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " + "If you do not intend to use cpu_offload, please consider set zero_stage=0." + ) + assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." + optimizer = MoeHybridParallelZeroOptimizer( + optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=dp_group, + tp_process_group=self.tp_group, + pp_process_group=self.pp_group, + moe_dp_group=self.moe_dp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.zero_config, + **self.amp_config, + ) + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, model) + + # Setup optimizers that require global states + optim = optimizer.optim + if isinstance(optim, DistributedOptim): + shard_to_param = optimizer.get_master_to_working_map() if is_zero else {} + padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int) + optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero) + + return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/booster/plugin/plugin/plugin_base.py b/colossalai/booster/plugin/plugin/plugin_base.py new file mode 100644 index 000000000000..6dc0c560d06d --- /dev/null +++ b/colossalai/booster/plugin/plugin/plugin_base.py @@ -0,0 +1,90 @@ +from abc import ABC, abstractmethod +from typing import Callable, Dict, Iterator, List, Optional, Tuple + +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader, Dataset + +from colossalai.checkpoint_io import CheckpointIO +from colossalai.interface import OptimizerWrapper + +__all__ = ["Plugin"] + + +class Plugin(ABC): + @abstractmethod + def supported_devices(self) -> List[str]: + pass + + @abstractmethod + def supported_precisions(self) -> List[str]: + pass + + @abstractmethod + def 
control_precision(self) -> bool: + pass + + @abstractmethod + def control_device(self) -> bool: + pass + + @abstractmethod + def support_no_sync(self) -> bool: + pass + + @abstractmethod + def support_lora(self) -> bool: + pass + + @abstractmethod + def configure( + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + # implement this method + pass + + @abstractmethod + def control_checkpoint_io(self) -> bool: + """ + Whether the plugin controls the checkpoint io + """ + + @abstractmethod + def get_checkpoint_io(self) -> CheckpointIO: + """ + Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True. + """ + + @abstractmethod + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + """ + Context manager to disable gradient synchronization. + """ + + @abstractmethod + def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module: + """ + Add LoRA modules to the model passed in. Should only be called in booster.enable_lora(). + """ + + @abstractmethod + def prepare_dataloader( + self, + dataset: Dataset, + batch_size: int, + shuffle: bool = False, + seed: int = 1024, + drop_last: bool = False, + pin_memory: bool = False, + num_workers: int = 0, + **kwargs, + ): + """Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` + """ diff --git a/colossalai/booster/plugin/plugin/pp_plugin_base.py b/colossalai/booster/plugin/plugin/pp_plugin_base.py new file mode 100644 index 000000000000..3d91eb95b409 --- /dev/null +++ b/colossalai/booster/plugin/plugin/pp_plugin_base.py @@ -0,0 +1,22 @@ +from abc import abstractmethod +from typing import Any, Callable, Iterator, Optional + +import torch + +from colossalai.interface import ModelWrapper, OptimizerWrapper + +from .plugin_base import Plugin + + +class PipelinePluginBase(Plugin): + @abstractmethod + def execute_pipeline( + self, + data_iter: Iterator, + model: ModelWrapper, + criterion: Callable[[Any, Any], torch.Tensor], + optimizer: Optional[OptimizerWrapper] = None, + return_loss: bool = True, + return_outputs: bool = False, + ) -> dict: + pass diff --git a/colossalai/booster/plugin/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/plugin/torch_ddp_plugin.py new file mode 100644 index 000000000000..5116446a4295 --- /dev/null +++ b/colossalai/booster/plugin/plugin/torch_ddp_plugin.py @@ -0,0 +1,257 @@ +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union + +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader + +from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO +from colossalai.cluster import DistCoordinator +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.quantization import BnbQuantizationConfig, quantize_model +from colossalai.utils import get_current_device + +from .dp_plugin_base import DPPluginBase + +__all__ = ["TorchDDPPlugin"] + + +class TorchDDPCheckpointIO(GeneralCheckpointIO): + def __init__(self) -> None: + super().__init__() + self.coordinator = DistCoordinator() + + def load_unsharded_model(self, model: ModelWrapper, 
checkpoint: str, strict: bool = True): + """ + Load model from checkpoint. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict) + + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save model to checkpoint but only on master process. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + if self.coordinator.is_master(): + super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors) + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): + """ + Load optimizer from checkpoint. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" + super().load_unsharded_optimizer(optimizer, checkpoint) + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + """ + Save optimizer to checkpoint but only on master process. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + if self.coordinator.is_master(): + super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save model to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False, + ): + """ + Save model to checkpoint but only on master process. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + if self.coordinator.is_master(): + super().save_sharded_model( + model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors + ) + + def load_sharded_model( + self, + model: ModelWrapper, + checkpoint_index_file: str, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): + """ + Load model from sharded checkpoint. + """ + assert isinstance(model, ModelWrapper), "Please boost the model before loading!" + super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module) + + def save_sharded_optimizer( + self, + optimizer: OptimizerWrapper, + checkpoint: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + ): + """ + Save optimizer to sharded checkpoint but only on master process. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" + if self.coordinator.is_master(): + super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard) + + def load_sharded_optimizer( + self, + optimizer: Optimizer, + index_file_path: str, + prefix: Optional[str] = None, + ): + """ + Load optimizer from sharded checkpoint. + """ + assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" 
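+        # Unwrap to the inner torch optimizer before delegating to the general checkpoint IO.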
+ super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix) + + def save_lora_as_pretrained( + self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False + ) -> None: + """ + Save the lora adapters and adapter configuration file to checkpoint directory. + """ + from peft import PeftModel + + assert isinstance(model, ModelWrapper), "Please boost the model before saving!" + if self.coordinator.is_master(): + peft_model = model.unwrap() + assert isinstance( + peft_model, PeftModel + ), "The model doesn't have lora adapters, please enable lora before saving." + peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors) + + +class TorchDDPModel(ModelWrapper): + def __init__(self, module: nn.Module, *args, **kwargs) -> None: + super().__init__(module) + self.module = DDP(module, *args, **kwargs) + + def unwrap(self): + return self.module.module + + +class TorchDDPPlugin(DPPluginBase): + """ + Plugin for PyTorch DDP. + + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import TorchDDPPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = TorchDDPPlugin() + + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` + + Args: + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True. + bucket_cap_mb (int, optional): The bucket size in MB. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters. Defaults to False. + check_reduction (bool, optional): Whether to check reduction. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False. + static_graph (bool, optional): Whether to use static graph. Defaults to False. 
+ """ + + def __init__( + self, + broadcast_buffers: bool = True, + bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + ) -> None: + super().__init__() + self.ddp_kwargs = dict( + broadcast_buffers=broadcast_buffers, + bucket_cap_mb=bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + ) + + def support_no_sync(self) -> bool: + return True + + def support_lora(self) -> bool: + return True + + def control_precision(self) -> bool: + return False + + def supported_precisions(self) -> List[str]: + return ["fp16", "fp16_apex", "bf16", "fp8"] + + def control_device(self) -> bool: + return True + + def supported_devices(self) -> List[str]: + return ["cuda", "npu"] + + def configure( + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + # cast model to cuda + model = model.to(get_current_device()) + + # convert model to sync bn + model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) + + # wrap the model with PyTorch DDP + model = TorchDDPModel(model, **self.ddp_kwargs) + + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + optimizer = OptimizerWrapper(optimizer) + + return model, optimizer, criterion, dataloader, lr_scheduler + + def control_checkpoint_io(self) -> bool: + return True + + def get_checkpoint_io(self) -> CheckpointIO: + return TorchDDPCheckpointIO() + + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin." + return model.module.no_sync() + + def enable_lora( + self, + model: nn.Module, + pretrained_dir: Optional[str] = None, + lora_config: Optional[Dict] = None, + bnb_quantization_config: Optional[BnbQuantizationConfig] = None, + ) -> nn.Module: + from peft import PeftModel, get_peft_model + + if bnb_quantization_config is not None: + model = quantize_model(model, bnb_quantization_config) + + assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model." 
+ if pretrained_dir is None: + return get_peft_model(model, lora_config) + else: + return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) diff --git a/colossalai/booster/plugin/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/plugin/torch_fsdp_plugin.py new file mode 100644 index 000000000000..7b67da032d66 --- /dev/null +++ b/colossalai/booster/plugin/plugin/torch_fsdp_plugin.py @@ -0,0 +1,372 @@ +import os +from pathlib import Path +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple + +import torch +import torch.nn as nn +from packaging import version +from torch.distributed import ProcessGroup + +if version.parse(torch.__version__) >= version.parse("1.12.0"): + from torch.distributed.fsdp import FullStateDictConfig + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp import StateDictType + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + BackwardPrefetch, + CPUOffload, + FullStateDictConfig, + MixedPrecision, + ShardingStrategy, + ) +else: + raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") + +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader + +from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils +from colossalai.cluster import DistCoordinator +from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger + +from .dp_plugin_base import DPPluginBase + +__all__ = ["TorchFSDPPlugin"] + + +class TorchFSDPCheckpointIO(GeneralCheckpointIO): + def __init__(self) -> None: + super().__init__() + self.coordinator = DistCoordinator() + self.logger = get_dist_logger() + + def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool): + assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" + model = model.unwrap() + checkpoint = utils.load_state_dict(checkpoint) + model.load_state_dict(checkpoint) + + def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path): + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!" + checkpoint = utils.load_state_dict(checkpoint) + fsdp_model = optimizer.unwrap_model() + sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model) + optimizer.load_state_dict(sharded_osd) + + def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + """ + Save model to checkpoint but only on master process. + """ + assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" + model = model.unwrap() + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + full_model_state = model.state_dict() + utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) + + def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): + """ + Save optimizer to checkpoint but only on master process. + """ + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" 
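+        # Gather the full optimizer state onto rank 0 before writing a single checkpoint file.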
+ fsdp_model = optimizer.unwrap_model() + full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) + utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) + + def save_sharded_model( + self, + model: ModelWrapper, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + size_per_shard: int = 1024, + use_safetensors: bool = False, + ): + """ + Save model to checkpoint but only on master process. + """ + assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" + if os.path.isfile(checkpoint_path): + self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + return + + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + with FSDP.state_dict_type( + model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + ): + state_dict = model.unwrap().state_dict() + + state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard) + + weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors) + index_file = CheckpointIndexFile(checkpoint_path) + + # In general cases, is_master is set to True to get the right behavior. + total_size = utils.save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=self.coordinator.is_master(), + use_safetensors=use_safetensors, + ) + + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + utils.save_config_file(model.unwrap(), checkpoint_path) + self.logger.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_sharded_model( + self, + model: nn.Module, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True, + ): + """ + Load model to checkpoint but only on master process. + """ + assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not utils.is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # read checkpoint index file + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + + fsdp_state_dict = {} + for shard_file in checkpoint_files: + fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors)) + + with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT): + model.unwrap().load_state_dict(fsdp_state_dict, strict=False) + + def save_sharded_optimizer( + self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int + ): + """ + Save optimizer to checkpoint but only on master process. + """ + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" 
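+        # Gather the full FSDP optimizer state on rank 0, then split it into size-limited shards plus an index file.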
+ + if os.path.isfile(checkpoint): + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") + return + + Path(checkpoint).mkdir(parents=True, exist_ok=True) + + with FSDP.state_dict_type( + optimizer.unwrap_model().unwrap(), + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + fsdp_optim_state = FSDP.full_optim_state_dict( + optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True + ) + + if self.coordinator.is_master(): + # Preparing file paths and index file. + states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix) + index_file = CheckpointIndexFile(checkpoint) + + index_file.append_meta_data("param_groups", param_group_file) + group_file_path = os.path.join(checkpoint, param_group_file) + utils.save_param_groups(fsdp_optim_state, group_file_path) + + sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard) + + # Save shards of optimizer states. + # In general cases, is_master is set to True to get the right behavior. + total_size = utils.save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=self.coordinator.is_master(), + use_safetensors=False, + ) + + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + self.logger.info( + f"The optimizer is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int): + """ + Load optimizer to checkpoint but only on master process. + """ + assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" + + ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) + + # Load param_groups + param_group_path = ckpt_index_file.get_param_group_filename() + if param_group_path is None: + raise RuntimeError( + f"Invalid index file path {index_file_path} for an optimizer. " + "Looking param group file under current directory." + ) + + saved_param_groups = torch.load(param_group_path) + + # Load param + fsdp_optim_state = {} + checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() + for shard_file in checkpoint_files: + state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) + fsdp_optim_state.update(state_dict_shard) + + fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups) + + with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT): + fsdp_state = FSDP.optim_state_dict_to_load( + model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict + ) + optimizer.load_state_dict(fsdp_state) + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save model to checkpoint but only on master process. 
+ """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + + +class TorchFSDPModel(ModelWrapper): + def __init__(self, module: nn.Module, *args, **kwargs) -> None: + super().__init__(module) + self.module = FSDP(module, *args, **kwargs) + + def unwrap(self): + return self.module + + +class FSDPOptimizerWrapper(OptimizerWrapper): + def __init__(self, optimizer: Optimizer, model: nn.Module): + self.model = model + super().__init__(optimizer) + + def unwrap_model(self) -> nn.Module: + return self.model + + +class TorchFSDPPlugin(DPPluginBase): + """ + Plugin for PyTorch FSDP. + + ```python + from colossalai.booster import Booster + from colossalai.booster.plugin import TorchFSDPPlugin + + model, train_dataset, optimizer, criterion = ... + plugin = TorchFSDPPlugin() + + train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + booster = Booster(plugin=plugin) + model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + ``` + + Args: + See https://pytorch.org/docs/stable/fsdp.html for details. + """ + + if version.parse(torch.__version__) >= version.parse("1.12.0"): + + def __init__( + self, + process_group: Optional[ProcessGroup] = None, + sharding_strategy: Optional[ShardingStrategy] = None, + cpu_offload: Optional[CPUOffload] = None, + auto_wrap_policy: Optional[Callable] = None, + backward_prefetch: Optional[BackwardPrefetch] = None, + mixed_precision: Optional[MixedPrecision] = None, + ignored_modules: Optional[Iterable[torch.nn.Module]] = None, + param_init_fn: Optional[Callable[[nn.Module], None]] = None, + sync_module_states: bool = False, + ): + super().__init__() + self.fsdp_kwargs = dict( + process_group=process_group, + sharding_strategy=sharding_strategy, + cpu_offload=cpu_offload, + auto_wrap_policy=auto_wrap_policy, + backward_prefetch=backward_prefetch, + mixed_precision=mixed_precision, + ignored_modules=ignored_modules, + param_init_fn=param_init_fn, + sync_module_states=sync_module_states, + ) + self.logger = get_dist_logger() + + else: + raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") + + def support_no_sync(self) -> bool: + return False + + def support_lora(self) -> bool: + return False + + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + raise NotImplementedError("Torch fsdp no_sync func not supported yet.") + + def control_precision(self) -> bool: + return True + + def supported_precisions(self) -> List[str]: + return ["fp16", "bf16"] + + def control_device(self) -> bool: + return True + + def supported_devices(self) -> List[str]: + return ["cuda"] + + def configure( + self, + model: nn.Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + # wrap the model with PyTorch FSDP + fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs) + + if optimizer is not None: + if len(optimizer.param_groups) > 1: + self.logger.warning( + "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used." 
+ ) + optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) + + if not isinstance(optimizer, FSDPOptimizerWrapper): + optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model) + + return fsdp_model, optimizer, criterion, dataloader, lr_scheduler + + def control_checkpoint_io(self) -> bool: + return True + + def get_checkpoint_io(self) -> CheckpointIO: + return TorchFSDPCheckpointIO() + + def enable_lora( + self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None + ) -> nn.Module: + raise NotImplementedError diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index cd2f9e84018a..7b67da032d66 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,6 +1,4 @@ -import logging import os -import warnings from pathlib import Path from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple @@ -30,6 +28,7 @@ from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from .dp_plugin_base import DPPluginBase @@ -40,6 +39,7 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() + self.logger = get_dist_logger() def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool): assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" @@ -88,7 +88,7 @@ def save_sharded_model( """ assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" if os.path.isfile(checkpoint_path): - logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") return Path(checkpoint_path).mkdir(parents=True, exist_ok=True) @@ -117,7 +117,7 @@ def save_sharded_model( index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) utils.save_config_file(model.unwrap(), checkpoint_path) - logging.info( + self.logger.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." @@ -162,7 +162,7 @@ def save_sharded_optimizer( assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): - logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -200,7 +200,7 @@ def save_sharded_optimizer( index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) - logging.info( + self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." 
@@ -311,6 +311,7 @@ def __init__( param_init_fn=param_init_fn, sync_module_states=sync_module_states, ) + self.logger = get_dist_logger() else: raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") @@ -349,7 +350,7 @@ def configure( if optimizer is not None: if len(optimizer.param_groups) > 1: - warnings.warn( + self.logger.warning( "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used." ) optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) From 227d824b54faf5fadcaff8df7b56df74ce7b8faa Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 16 Aug 2024 09:30:32 +0000 Subject: [PATCH 2/5] remove trash --- colossalai/booster/plugin/plugin/__init__.py | 23 - .../booster/plugin/plugin/dp_plugin_base.py | 76 - .../booster/plugin/plugin/gemini_plugin.py | 595 ------- .../plugin/plugin/hybrid_parallel_plugin.py | 1466 ----------------- .../plugin/plugin/low_level_zero_plugin.py | 521 ------ .../plugin/moe_hybrid_parallel_plugin.py | 490 ------ .../booster/plugin/plugin/plugin_base.py | 90 - .../booster/plugin/plugin/pp_plugin_base.py | 22 - .../booster/plugin/plugin/torch_ddp_plugin.py | 257 --- .../plugin/plugin/torch_fsdp_plugin.py | 372 ----- 10 files changed, 3912 deletions(-) delete mode 100644 colossalai/booster/plugin/plugin/__init__.py delete mode 100644 colossalai/booster/plugin/plugin/dp_plugin_base.py delete mode 100644 colossalai/booster/plugin/plugin/gemini_plugin.py delete mode 100644 colossalai/booster/plugin/plugin/hybrid_parallel_plugin.py delete mode 100644 colossalai/booster/plugin/plugin/low_level_zero_plugin.py delete mode 100644 colossalai/booster/plugin/plugin/moe_hybrid_parallel_plugin.py delete mode 100644 colossalai/booster/plugin/plugin/plugin_base.py delete mode 100644 colossalai/booster/plugin/plugin/pp_plugin_base.py delete mode 100644 colossalai/booster/plugin/plugin/torch_ddp_plugin.py delete mode 100644 colossalai/booster/plugin/plugin/torch_fsdp_plugin.py diff --git a/colossalai/booster/plugin/plugin/__init__.py b/colossalai/booster/plugin/plugin/__init__.py deleted file mode 100644 index 7e0e6ffdd8e8..000000000000 --- a/colossalai/booster/plugin/plugin/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from .gemini_plugin import GeminiPlugin -from .hybrid_parallel_plugin import HybridParallelPlugin -from .low_level_zero_plugin import LowLevelZeroPlugin -from .moe_hybrid_parallel_plugin import MoeHybridParallelPlugin -from .plugin_base import Plugin -from .torch_ddp_plugin import TorchDDPPlugin - -__all__ = [ - "Plugin", - "TorchDDPPlugin", - "GeminiPlugin", - "LowLevelZeroPlugin", - "HybridParallelPlugin", - "MoeHybridParallelPlugin", -] - -import torch -from packaging import version - -if version.parse(torch.__version__) >= version.parse("1.12.0"): - from .torch_fsdp_plugin import TorchFSDPPlugin - - __all__.append("TorchFSDPPlugin") diff --git a/colossalai/booster/plugin/plugin/dp_plugin_base.py b/colossalai/booster/plugin/plugin/dp_plugin_base.py deleted file mode 100644 index 27285f95ce52..000000000000 --- a/colossalai/booster/plugin/plugin/dp_plugin_base.py +++ /dev/null @@ -1,76 +0,0 @@ -import random - -import numpy as np -import torch -import torch.distributed as dist -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from .plugin_base import Plugin - - -class DPPluginBase(Plugin): - """This is a base class for all DP plugins. 
It sets up world size and rank, and provides data loader creation.""" - - def __init__(self) -> None: - super().__init__() - assert ( - dist.is_initialized() - ), "torch.distributed is not initialized, please use colossalai.launch to create the distributed environment" - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - - def prepare_dataloader( - self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - distributed_sampler_cls=None, - **kwargs, - ): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - distributed_sampler_cls = distributed_sampler_cls or DistributedSampler - sampler = distributed_sampler_cls(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs, - ) diff --git a/colossalai/booster/plugin/plugin/gemini_plugin.py b/colossalai/booster/plugin/plugin/gemini_plugin.py deleted file mode 100644 index 443c80831b14..000000000000 --- a/colossalai/booster/plugin/plugin/gemini_plugin.py +++ /dev/null @@ -1,595 +0,0 @@ -import gc -import os -import random -from pathlib import Path -from typing import Callable, Dict, Iterator, List, Optional, Tuple - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.distributed.distributed_c10d import _get_default_group -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from colossalai.accelerator import get_accelerator -from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO -from colossalai.checkpoint_io.utils import ( - get_model_base_filenames, - get_optimizer_base_filenames, - load_shard_state_dict, - save_config_file, - save_state_dict, - save_state_dict_shards, -) -from colossalai.cluster import DistCoordinator, ProcessGroupMesh -from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.logging import 
get_dist_logger -from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.zero import GeminiDDP, GeminiOptimizer -from colossalai.zero.gemini.memory_tracer import MemStats - -from .dp_plugin_base import DPPluginBase - -__all__ = ["GeminiPlugin"] - -SUPPORTED_PRECISION = ["fp16", "bf16"] -PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} - -ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 - - -def get_param_info(optim: Optimizer): - # Get a backup of necessary information of parameters for future use, which includes: - # 1. A mapping from integer param_id to param32 shape. - if optim is None: - return {} - param_info = {"id2shape": {}} - - start_index = 0 - for group in optim.param_groups: - for param_id, param in enumerate(group["params"], start_index): - original_shape = param.shape if isinstance(param, torch.Tensor) else None - param_info["id2shape"][param_id] = original_shape - - start_index += len(group["params"]) - - return param_info - - -class GeminiCheckpointIO(GeneralCheckpointIO): - def __init__(self) -> None: - super().__init__() - self.coordinator = DistCoordinator() - self.logger = get_dist_logger() - - def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): - """ - Save sharded model to checkpoint but only on master process. - The model should be unwrapped in self.load_model via ModelWrapper.unwrap. - As there is communication when getting state dict, model.state_dict() must be called on all processes. - """ - assert isinstance(model, GeminiDDP), "Please boost the model before saving!" - state_dict = model.state_dict(only_rank_0=True) - if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors) - - def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): - """ - Load model from checkpoint with automatic unwrapping. - The model should be unwrapped in self.load_model via ModelWrapper.unwrap. - """ - assert isinstance(model, GeminiDDP), "Please boost the model before loading!" - super().load_unsharded_model(model, checkpoint, strict=strict) - - def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool): - """ - Save unsharded optimizer state dict to checkpoint. - After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank. - As there is communication when getting state dict, optimizer.state_dict() must be called on all processes. - The saving process will only be executed by master rank. - """ - assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" - state_dict = optimizer.state_dict() - if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors=False) - - def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str): - """ - Loading unsharded optimizer from checkpoint file. - For each process, only loading optimizer states of parameters it controls. - """ - assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" - super().load_unsharded_optimizer(optimizer, checkpoint) - - def save_sharded_model( - self, - model: GeminiDDP, - checkpoint_path: str, - gather_dtensor: bool = False, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False, - ): - """ - Save sharded model. - As there is communication when getting state dict, model.state_dict() must be called on all processes. 
- """ - assert isinstance(model, GeminiDDP), "Please boost the model before saving!" - if os.path.isfile(checkpoint_path): - self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") - return - - Path(checkpoint_path).mkdir(parents=True, exist_ok=True) - - state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True) - weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) - index_file = CheckpointIndexFile(checkpoint_path) - - # Save shards of optimizer states. - is_master = self.coordinator.is_master() - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=is_master, - use_safetensors=use_safetensors, - ) - - # only save the index file on the master rank - if self.coordinator.is_master(): - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model.unwrap(), checkpoint_path) - self.logger.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) - - def load_sharded_model( - self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False - ): - """ - Load shard model, load model from multiple files. - """ - assert isinstance(model, GeminiDDP), "Please boost the model before loading!" - return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) - - def save_sharded_optimizer( - self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int - ): - """ - Save sharded optimizer state dict to checkpoint folder. - As there is communication when getting state dict, this must be called on all processes. - """ - assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" - - if os.path.isfile(checkpoint): - self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") - return - - Path(checkpoint).mkdir(parents=True, exist_ok=True) - - # Preparing file paths and index file. - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) - index_file = CheckpointIndexFile(checkpoint) - index_file.append_meta_data("param_groups", param_group_file) - - # Store the information of param groups to param_group_file. - if self.coordinator.is_master(): - group_file_path = os.path.join(checkpoint, param_group_file) - param_groups = optimizer.get_param_groups_for_saving() - torch.save(param_groups, group_file_path) - - # States are broken into shards within max_shard_size. - state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True) - - # Save shards of optimizer states. - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=self.coordinator.is_master(), - use_safetensors=False, - ) - - # Wrap up index file. Only save it on master rank. - if self.coordinator.is_master(): - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - self.logger.info( - f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." 
- ) - - def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str): - """ - Loading sharded optimizer from checkpoint folder, with index file given. - For each process, only loading optimizer states of parameters it controls. - """ - assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" - if not os.path.isfile(checkpoint_index_file): - self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file") - - assert isinstance(optimizer, GeminiOptimizer) - - # Read checkpoint index file. - ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) - - # Load param_groups. - param_group_path = ckpt_index_file.get_param_group_filename() - if param_group_path is None: - raise RuntimeError( - f"Invalid index file path {checkpoint_index_file} for an optimizer. \ - Lacking param group file under current directory." - ) - saved_param_groups = torch.load(param_group_path) - optimizer.load_param_groups(saved_param_groups) - - checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() - - # Load optimizer states from shard files under checkpoint path. - # For each file, only load the states managed by current process. - for shard_file in checkpoint_files: - state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False) - optimizer.load_param_states(state_dict_shard) - del state_dict_shard - gc.collect() - - optimizer.optimizer_loading_epilogue() - - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): - """ - Save model to checkpoint but only on master process. - """ - if self.coordinator.is_master(): - super().save_lr_scheduler(lr_scheduler, checkpoint) - - -class GeminiPlugin(DPPluginBase): - """ - Plugin for Gemini. - - ```python - from colossalai.booster import Booster - from colossalai.booster.plugin import GeminiPlugin - - model, train_dataset, optimizer, criterion = ... - plugin = GeminiPlugin() - - train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) - booster = Booster(plugin=plugin) - model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) - ``` - - Args: - chunk_config_dict (dict, optional): chunk configuration dictionary. - chunk_init_device (torch.device, optional): device to initialize the chunk. - placement_policy (str, optional): "static" and "auto". Defaults to "static". - enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradient will be stored after doing backward pass. Defaults to False. - shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement. - If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0. - offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement. - If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0. - offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement. - For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0. - If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement. - When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`. 
- Defaults to 0.0. - warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8. - steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9. - precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'. - master_weights (bool, optional): Whether to keep fp32 master parameter weights in optimizer. Defaults to True. - pin_memory (bool, optional): use pin memory on CPU. Defaults to False. - force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. - strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. - search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32. - hidden_dim (int, optional): the hidden dimension of DNN. - Users can provide this argument to speed up searching. - If users do not know this argument before training, it is ok. We will use a default value 1024. - min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20. - If the aggregate size of parameters is still smaller than the minimum chunk size, - all parameters will be compacted into one small chunk. - memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. - gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward) - which will be used when using hybrid CPU optimizer. - This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto". - Defaults to 0.0. - initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16. - min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. - growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. - backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. - growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. - hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. - max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. - max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do - clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. - norm_type (float, optional): norm_type used for `clip_grad_norm`. - tp_size (int, optional): If 'tp_size' is set to be greater than 1, it means using tensor parallelism strategy, which is implemented in Shardformer, 'tp_size' determines the size of the tensor parallel process group. Default to 1. - extra_dp_size (int, optional): If 'extra_dp_size' is set to be greater than 1, it means creating another group to run with a ddp-like strategy. Default to 1. - enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. - Currently all the optimization methods include fused normalization, flash attention and JIT. - Defaults to False. - enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. - enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. - enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. 
- enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. - enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. - verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. - """ - - def __init__( - self, - chunk_config_dict: Optional[dict] = None, - chunk_init_device: Optional[torch.device] = None, - placement_policy: str = "static", - enable_gradient_accumulation: bool = False, - max_prefetch: int = 0, - shard_param_frac: float = 1.0, # only for static placement - offload_optim_frac: float = 0.0, # only for static placement - offload_param_frac: float = 0.0, # only for static placement - warmup_non_model_data_ratio: float = 0.8, # only for auto placement - steady_cuda_cap_ratio: float = 0.9, # only for auto placement - precision: str = "fp16", - master_weights: bool = True, - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False, - search_range_m: int = 32, - hidden_dim: Optional[int] = None, - min_chunk_size_m: float = 32, - memstats: Optional[MemStats] = None, - gpu_margin_mem_ratio: float = 0.0, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0, - tp_size: int = 1, - extra_dp_size: int = 1, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_sequence_parallelism: bool = False, - enable_jit_fused: bool = False, - enable_sequence_overlap: bool = False, - enable_async_reduce: bool = True, - verbose: bool = False, - ) -> None: - super().__init__() - assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" - if get_accelerator().name == "npu": - assert placement_policy == "static", "NPU only supports static placement policy" - - self.logger = get_dist_logger() - if enable_async_reduce and not pin_memory: - self.logger.warning( - f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set." 
- ) - pin_memory = True - self.gemini_config = dict( - chunk_config_dict=chunk_config_dict, - chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()), - placement_policy=placement_policy, - enable_gradient_accumulation=enable_gradient_accumulation, - shard_param_frac=shard_param_frac, - offload_optim_frac=offload_optim_frac, - offload_param_frac=offload_param_frac, - warmup_non_model_data_ratio=warmup_non_model_data_ratio, - steady_cuda_cap_ratio=steady_cuda_cap_ratio, - pin_memory=pin_memory, - force_outputs_fp32=force_outputs_fp32, - strict_ddp_mode=strict_ddp_mode, - search_range_m=search_range_m, - hidden_dim=hidden_dim, - min_chunk_size_m=min_chunk_size_m, - memstats=memstats, - mixed_precision=PRECISION_STR_TO_DTYPE[precision], - master_weights=master_weights, - max_prefetch=max_prefetch, - enable_async_reduce=enable_async_reduce, - ) - self.zero_optim_config = dict( - gpu_margin_mem_ratio=gpu_margin_mem_ratio, - ) - self.optim_kwargs = dict( - initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - max_norm=max_norm, - norm_type=norm_type, - ) - self.enable_tensor_parallelism = tp_size > 1 - self.enable_all_optimization = enable_all_optimization - self.enable_fused_normalization = enable_fused_normalization - self.enable_flash_attention = enable_flash_attention - self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False - self.enable_jit_fused = enable_jit_fused - self.enable_sequence_overlap = enable_sequence_overlap - self.verbose = verbose - - self.tp_size = tp_size - self.extra_dp_size = extra_dp_size - world_size = dist.get_world_size() - self.zero_size = world_size // (self.tp_size * self.extra_dp_size) - assert ( - world_size == (self.tp_size * self.extra_dp_size) * self.zero_size - ), f"The global group size can't be evenly divided by the subgroup size." 
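As a quick sanity check of the divisibility assertion above, here is a worked sizing example with made-up numbers (16 ranks, `tp_size=2`, `extra_dp_size=2`), not taken from the patch:

```python
# Illustrative sizing only: 16 ranks split into tp=2, extra_dp=2, zero=4.
world_size, tp_size, extra_dp_size = 16, 2, 2
zero_size = world_size // (tp_size * extra_dp_size)         # 16 // 4 == 4
assert world_size == (tp_size * extra_dp_size) * zero_size  # 2 * 2 * 4 == 16
# If world_size were 18 with the same tp/extra_dp sizes, 18 // 4 == 4 and
# 2 * 2 * 4 == 16 != 18, so the assertion would fail.
```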
- - self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size) - self.zero_group = ( - self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group() - ) - self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None - self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None - self.dp_size = self.zero_size * self.extra_dp_size - - self.shard_config = ShardConfig( - tensor_parallel_process_group=self.tp_group, - enable_tensor_parallelism=self.enable_tensor_parallelism, - enable_all_optimization=self.enable_all_optimization, - enable_fused_normalization=self.enable_fused_normalization, - enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=self.enable_sequence_parallelism, - enable_sequence_overlap=self.enable_sequence_overlap, - ) - - def __del__(self): - """Destroy the process groups in ProcessGroupMesh""" - self.pg_mesh.destroy_mesh_process_groups() - - def support_no_sync(self) -> bool: - return False - - def support_lora(self) -> bool: - return False - - def control_precision(self) -> bool: - return True - - def supported_precisions(self) -> List[str]: - return SUPPORTED_PRECISION - - def control_device(self) -> bool: - return True - - def supported_devices(self) -> List[str]: - return ["cuda", "npu"] - - def prepare_dataloader( - self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - distributed_sampler_cls=None, - **kwargs, - ): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
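The method body that follows flattens the (ZeRO, extra-DP) coordinates into a single sampler rank. A small worked example with assumed group sizes (the numbers are illustrative, not from the patch):

```python
# Assumed sizes: 4-way ZeRO group and 2-way extra DP group -> 8 data replicas.
zero_world_size, extra_dp_world_size = 4, 2
zero_rank, extra_dp_rank = 3, 1
num_replicas = zero_world_size * extra_dp_world_size             # 8
sampler_rank = zero_rank * extra_dp_world_size + extra_dp_rank   # 3 * 2 + 1 == 7
assert 0 <= sampler_rank < num_replicas
```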
- """ - _kwargs = kwargs.copy() - zero_world_size = self.pg_mesh.size(ZERO_AXIS) - extra_dp_world_size = self.pg_mesh.size(DP_AXIS) - zero_rank = self.pg_mesh.coordinate(ZERO_AXIS) - extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS) - distributed_sampler_cls = distributed_sampler_cls or DistributedSampler - sampler = distributed_sampler_cls( - dataset, - num_replicas=zero_world_size * extra_dp_world_size, - rank=zero_rank * extra_dp_world_size + extra_dp_rank, - shuffle=shuffle, - ) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs, - ) - - def configure( - self, - model: nn.Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - dataloader: Optional[DataLoader] = None, - lr_scheduler: Optional[LRScheduler] = None, - ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - params_info = get_param_info(optimizer) - if not isinstance(model, ModelWrapper): - # convert model to sync bn - # FIXME(ver217): gemini does not support sync bn - # In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16. - # This inconsistency of dtype will cause the error. - # We have two possible solutions: - # 1. keep batch norm always in fp32. This is hard for gemini, as it use chunks. - # 2. patch sync bn or write a new on. This is relatively easy, but we need to test it. - # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) - - # wrap the model with Gemini - if self.enable_tensor_parallelism: - shardformer = ShardFormer(self.shard_config) - model, _ = shardformer.optimize(model) - - model = GeminiDDP( - model, - **self.gemini_config, - zero_group=self.zero_group, - extra_dp_group=self.extra_dp_group, - verbose=self.verbose, - ) - - if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): - optimizer = GeminiOptimizer( - optimizer, - model, - **self.zero_optim_config, - **self.optim_kwargs, - tp_group=self.tp_group, - params_info=params_info, - verbose=self.verbose, - ) - - return model, optimizer, criterion, dataloader, lr_scheduler - - def control_checkpoint_io(self) -> bool: - return True - - def get_checkpoint_io(self) -> CheckpointIO: - return GeminiCheckpointIO() - - def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - raise NotImplementedError - - def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None - ) -> nn.Module: - raise NotImplementedError diff --git a/colossalai/booster/plugin/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/plugin/hybrid_parallel_plugin.py deleted file mode 100644 index 6c1515d38834..000000000000 --- a/colossalai/booster/plugin/plugin/hybrid_parallel_plugin.py +++ /dev/null @@ -1,1466 +0,0 @@ -import ctypes -import random -from collections import defaultdict -from contextlib import contextmanager, nullcontext -from copy import deepcopy -from functools import partial -from types import MethodType -from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union - -import numpy as np -import torch -import torch.distributed as dist -from torch import Tensor, inf 
-from torch.distributed import ProcessGroup, get_world_size -from torch.nn import Module, SyncBatchNorm -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils._pytree import tree_map -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from colossalai.accelerator import get_accelerator -from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer -from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO -from colossalai.cluster import ProcessGroupMesh -from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper -from colossalai.interface.optimizer import DistributedOptim -from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed -from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.quantization import BnbQuantizationConfig, quantize_model -from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer -from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.tensor.d_tensor.api import is_distributed_tensor -from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.zero.low_level import LowLevelZeroOptimizer -from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle - -from .pp_plugin_base import PipelinePluginBase - -SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all", "ring_attn"] - -PRECISION_TORCH_TYPE = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16} - - -def _convert_floating_point(x, dtype: torch.dtype = torch.float16): - if isinstance(x, torch.Tensor) and torch.is_floating_point(x): - return x.to(dtype) - return x - - -class HybridParallelModule(ModelWrapper, AMPModelMixin): - def __init__( - self, - module: Module, - precision: str, - shard_config: ShardConfig, - dp_group: ProcessGroup, - tp_group: ProcessGroup, - sp_group: ProcessGroup, - use_ddp: bool, - ddp_config: dict, - custom_policy: Policy, - overlap_allgather: bool = False, - ) -> None: - self.stage_manager = shard_config.pipeline_stage_manager - self.shard_config = shard_config - self.dp_group = dp_group - self.tp_group = tp_group - self.sp_group = sp_group - self.use_ddp = use_ddp - self.require_grad_sync = True - self.overlap_allgather = overlap_allgather - - shardformer = ShardFormer(shard_config) - if custom_policy is not None: - assert isinstance(custom_policy, object) - module, self.shared_params = shardformer.optimize(module, policy=custom_policy) - - # setting process groups for shared parameters - self.shared_param_process_groups = [] - for shared_param in self.shared_params: - if len(shared_param) > 0: - self.shared_param_process_groups.append( - self.stage_manager.init_process_group_by_stages(list(shared_param.keys())) - ) - - # setting mixed_precision - self.mixed_precision = None - if precision == "fp16": - self.mixed_precision = torch.float16 - elif precision == "bf16": - self.mixed_precision = torch.bfloat16 - if self.mixed_precision is not None: - module = module.to(self.mixed_precision) - module = 
module.to(get_accelerator().get_current_device()) - - # setting input type cast when using mixed precision - self.convert_fn = None - if self.mixed_precision is not None: - self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision) - - # setting ddp configs - if use_ddp: - # convert model to sync bn - module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group) - # wrap the model with PyTorch DDP - module = DDP(module, process_group=dp_group, **ddp_config) - - super().__init__(module) - if overlap_allgather: - self.op_hook = ZeroOpHook() - for p in module.parameters(): - if p.requires_grad and type(p) is not ColoParameter: - p.__class__ = ColoParameter - p.__init__(p, requires_grad=True) - - def sync_shared_params(self): - for shared_param, group in zip(self.shared_params, self.shared_param_process_groups): - if self.stage_manager.stage in shared_param: - param = shared_param[self.stage_manager.stage] - dist.all_reduce(param.grad, group=group) - dist.barrier() - - @contextmanager - def no_sync(self): - r""" - A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization - when 'no_sync' is active. Alternatively, synchronization will occur in the first forward-backward pass - when exiting the context. - """ - - # Store the current value of 'require_grad_sync' to restore it later. - old_require_grad_sync = self.require_grad_sync - # Disable automatic gradient synchronization. - self.require_grad_sync = False - try: - if self.use_ddp: - # If using data parallel processing (use_ddp), disable synchronization too. - with self.module.no_sync(): - yield - else: - yield - finally: - # Restore the original value of 'require_grad_sync'. - self.require_grad_sync = old_require_grad_sync - - def sync_dp_grads(self): - r""" - Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1. - This function performs an all-reduce operation to combine gradients from different devices in the DP group. - - Args: - None - - Returns: - None - """ - - # Check if the DP group size is 1, meaning no synchronization is needed. - if self.dp_group.size() == 1: - return - - # Iterate through the model's parameters and perform gradient synchronization. - for p in self.module.parameters(): - if p.grad is not None: - # Perform all-reduce to combine gradients from different devices. - dist.all_reduce(p.grad, group=self.dp_group) - # Normalize the gradient by dividing it by the DP group size. - p.grad.div_(self.dp_group.size()) - - def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None): - r""" - Synchronize gradients that are partially derived within sequence parallelism - if sequence parallelism is enabled. Gradients can be provided explicitly or extracted - from the module. - - Args: - grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not - provided, gradients will be extracted from the model. - - Returns: - None - """ - - if self.shard_config.enable_sequence_parallelism: - if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: - return - - if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]: - # If sequence parallelism is enabled and mode is split_gather or ring, gradients are synchronized - # across the tensor parallelism group. 
- group = self.tp_group - else: - raise ValueError(f"Unknown sequence parallelism mode: {self.shard_config.sequence_parallelism_mode}") - - if grads is not None: - # Synchronize provided gradient tensors across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(process_group=group, grads=grads) - else: - # Synchronize gradients from the model across the tensor parallelism group. - SeqParallelUtils.allreduce_partial_data_grad(process_group=group, model=self.module) - - def forward(self, *args, **kwargs): - if self.convert_fn is not None: - args = tree_map(self.convert_fn, args) - kwargs = tree_map(self.convert_fn, kwargs) - with self._wait_all_gather(): - return super().forward(*args, **kwargs) - - def unwrap(self): - module = super().unwrap() - if isinstance(module, DDP): - module = module.module - return module - - def _force_wait_all_gather(self): - for p in self.module.parameters(): - wait_all_gather_handle(p) - - def _wait_all_gather(self): - return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() - - -def get_param_info(optim: Optimizer): - # Get a backup of necessary information of parameters for future use, which includes: - # 1. A complete param_group, with params in the form of param_id - # 2. A mapping from param address (obtained using id(param)) to integer param_id - # 3. A mapping from integer param_id to param address. - # 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding. - # When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer. - - if optim is None: - return {} - param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}} - start_index = 0 - for group in optim.param_groups: - packed_group = {k: v for k, v in group.items() if k != "params"} - packed_group["params"] = [] - - for param_id, param in enumerate(group["params"], start_index): - original_shape = param.shape if isinstance(param, torch.Tensor) else None - packed_group["params"].append(param_id) - param_info["param2id"][id(param)] = param_id - param_info["id2param"][param_id] = id(param) - param_info["param2shape"][id(param)] = original_shape - - param_info["param_groups"].append(packed_group) - start_index += len(group["params"]) - - return param_info - - -def reinitialize_optimizer(optim: Optimizer, model: Module): - model_params = set(model.parameters()) - new_param_groups = [] - for group in optim.param_groups: - params = [p for p in group["params"] if p in model_params] - new_param_groups.append({**group, "params": params}) - optim.__setstate__({"param_groups": new_param_groups}) - - -class HybridParallelNaiveOptimizer(OptimizerWrapper): - def __init__( - self, - optim: Optimizer, - model: HybridParallelModule, - use_pipeline: bool, - param_info: OrderedDict, - max_norm: float = 0, - tp_process_group: Optional[ProcessGroup] = None, # if using tp - pp_process_group: Optional[ProcessGroup] = None, # if using pp - ): - self.param_info = param_info - if use_pipeline: - reinitialize_optimizer(optim, model) - self.model = model - self.stage_manager = model.stage_manager - self.shared_params = model.shared_params - self.max_norm = max_norm - self.tp_pg = tp_process_group - self.pp_pg = pp_process_group - self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 - self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 - super().__init__(optim) - - def backward(self, loss: Tensor, 
*args, **kwargs): - r""" - Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. - - This method performs backward pass for gradient computation. If sequence parallelism is enabled - and gradient synchronization is required, it will synchronize gradients that are partially derived - within sequence parallelism across tp parallelism groups. - - Args: - loss (Tensor): The loss tensor to compute gradients with respect to. - *args: Additional positional arguments to be passed to the superclass backward method. - **kwargs: Additional keyword arguments to be passed to the superclass backward method. - - Returns: - None - """ - - # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) - - if self.model.require_grad_sync: - # If gradient synchronization is required, sync sequence parallelism gradients. - self.model.sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def backward_by_grad(self, tensor: Tensor, grad: Tensor): - """ - Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. - - This method performs a backward pass for gradient computation using a precomputed gradient tensor. - If sequence parallelism is enabled and gradient synchronization is required, it will synchronize - gradients that are partially derived within sequence parallelism across tp parallelism groups. - - Args: - tensor (Tensor): The input tensor for which gradients are computed. - grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. - - Returns: - None - """ - - # Call the superclass backward method to compute gradients. - super().backward_by_grad(tensor, grad) - - if self.model.require_grad_sync: - # If gradient synchronization is required, sync sequence parallelism gradients. - self.model.sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def step(self, *args, **kwargs): - r""" - Perform an optimization step. - - Args: - *args: Variable-length positional arguments to be passed to the optimizer's step function. - **kwargs: Keyword arguments to be passed to the optimizer's step function. - """ - - if self.max_norm > 0: - # Compute the total gradient norm. - param_gradient_pairs = [ - (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None - ] - total_norm = self._compute_grad_norm(param_gradient_pairs) - - # Clip the gradients to prevent exploding gradients. - self._clip_grad_norm(total_norm) - - # Perform the optimization step using the underlying optimizer. - self.optim.step(*args, **kwargs) - - def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: - r""" - Compute and return the gradient norm for gradient clipping. - - Args: - param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. - norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. - - Returns: - float: The total norm of the given gradients. - """ - - if len(param_gradient_pairs) == 0: - return 0.0 - - norm_type = float(norm_type) - - # gradients used for norm calculation. 
- gradients = [grad for param, grad in param_gradient_pairs] - - if norm_type == inf: - total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.tensor( - [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - if self.tp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) - if self.pp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) - total_norm = total_norm_cuda.item() - else: - # gradients used for norm calculation. - gradients = [grad for param, grad in param_gradient_pairs] - # grad_to_param_mapping is used to check which gradients are not distributed across devices of the 'tp_group'. - grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} - - total_norm_exponentiated = 0.0 - for grad in gradients: - grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type - - # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, - # it indicates that the parameter is not distributed across devices of the 'tp_group'. - # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. - # However, we still perform the 'all_reduce' operation for the sake of good coding practices. - # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' - if self.tp_size > 1: - param_for_grad = grad_to_param_mapping[id(grad)] - if not is_distributed_tensor(param_for_grad): - grad_norm_exponentiated /= self.tp_size - - # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, - # it means that this parameter is used in two different pipeline stages. - # To avoid redundant norm calculations, we divide the exponent of this norm by - # the number of shared stages. - if self.pp_size > 1: - for shared_param in self.shared_params: - if self.stage_manager.stage in shared_param: - stage_shared_param = shared_param[self.stage_manager.stage] - if grad is stage_shared_param.grad: - grad_norm_exponentiated /= len(shared_param) - - total_norm_exponentiated += grad_norm_exponentiated - - total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - if self.tp_size > 1: - # compute norm in tp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) - if self.pp_size > 1: - # compute norm in pp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) - - # compute the total_norm - total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) - - return total_norm - - def _clip_grad_norm(self, total_norm: float) -> None: - r""" - Clips the gradients of the model's parameters to prevent exploding gradients. - - Args: - total_norm (float): The computed total gradient norm. 
- - Returns: - None - """ - clip_coef = torch.tensor(self.max_norm / (total_norm + 1e-6)) - clip_coef_clamped = torch.clamp(clip_coef, max=1.0) - - for group in self.optim.param_groups: - for p in group["params"]: - if p.grad is None: - continue - p.grad.data.mul_(clip_coef_clamped) - - def update_master_params(self, model: Module): - pass - - def get_working_to_master_map(self): - return None - - def get_master_to_working_map(self): - return None - - -class HybridParallelAMPOptimizer(MixedPrecisionOptimizer): - def __init__( - self, - optim: Optimizer, - model: HybridParallelModule, - use_pipeline: bool, - param_info: OrderedDict, - precision: str = "fp16", - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - tp_process_group: Optional[ProcessGroup] = None, # if using tp - pp_process_group: Optional[ProcessGroup] = None, # if using pp - ): - self.model = model - self.param_info = param_info - self.stage_manager = model.stage_manager - self.shared_params = model.shared_params - self.tp_pg = tp_process_group - self.pp_pg = pp_process_group - self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 - self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 - if use_pipeline: - reinitialize_optimizer(optim, model) - super().__init__( - optim, - precision=precision, - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - max_norm=max_norm, - ) - - def backward(self, loss: Tensor, *args, **kwargs): - r""" - Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. - - This method performs backward pass for gradient computation. If sequence parallelism is enabled - and gradient synchronization is required, it will synchronize gradients that are partially derived - within sequence parallelism across tp parallelism groups. - - Args: - loss (Tensor): The loss tensor to compute gradients with respect to. - *args: Additional positional arguments to be passed to the superclass backward method. - **kwargs: Additional keyword arguments to be passed to the superclass backward method. - - Returns: - None - """ - # Call the superclass backward method to compute gradients. - super().backward(loss, *args, **kwargs) - - if self.model.require_grad_sync: - # If gradient synchronization is required, sync sequence parallelism gradients. - self.model.sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def backward_by_grad(self, tensor: Tensor, grad: Tensor): - """ - Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. - - This method performs a backward pass for gradient computation using a precomputed gradient tensor. - If sequence parallelism is enabled and gradient synchronization is required, it will synchronize - gradients that are partially derived within sequence parallelism across tp parallelism groups. - - Args: - tensor (Tensor): The input tensor for which gradients are computed. - grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor. - - Returns: - None - """ - # Call the superclass backward method to compute gradients. 
- super().backward_by_grad(tensor, grad) - - if self.model.require_grad_sync: - # If gradient synchronization is required, sync sequence parallelism gradients. - self.model.sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int: - r""" - Compute and return the gradient norm for gradient clipping. - - Args: - param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation. - norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2. - - Returns: - float: The total norm of the given gradients. - """ - if len(param_gradient_pairs) == 0: - return 0.0 - - norm_type = float(norm_type) - - if norm_type == inf: - # The parent class calculates the norm of 'dp' gradients, - # so we need to calculate the norm of 'tp' and 'pp' gradients. - total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type) - - total_norm_cuda = torch.tensor( - [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - - if self.tp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) - if self.pp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) - - total_norm = total_norm_cuda.item() - - else: - # gradients used for norm calculation. - gradients = [grad for param, grad in param_gradient_pairs] - # grad_to_param_mapping is used to check which gradients are not distributed in tensor parallelism. - grad_to_param_mapping = {id(grad): param for param, grad in param_gradient_pairs} - - total_norm_exponentiated = 0.0 - for grad in gradients: - grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type - - # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, - # it indicates that the parameter is not distributed across devices of the 'tp_group'. - # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. - # However, we still perform the 'all_reduce' operation for the sake of good coding practices. - # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' - if self.tp_size > 1: - param_for_grad = grad_to_param_mapping[id(grad)] - if not is_distributed_tensor(param_for_grad): - grad_norm_exponentiated /= self.tp_size - - # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, - # it means that this parameter is used in two different pipeline stages. - # To avoid redundant norm calculations, we divide the exponent of this norm by - # the number of shared stages. 
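For readers following the norm bookkeeping here, a standalone sketch of the aggregation performed by `_compute_grad_norm` (purely illustrative: `local_grads`, `p`, and the placeholder group name in the commented-out `all_reduce` are made up, and the collective call is omitted because the sketch runs on a single process):

```python
import torch

# Each rank sums ||g||^p over its local gradients; the sums would then be
# all-reduced with SUM across the relevant process groups, and the total norm
# is the p-th root of the reduced sum.
p = 2.0
local_grads = [torch.randn(4), torch.randn(8)]  # made-up local gradients
local_sum = sum(g.double().norm(p) ** p for g in local_grads)
# dist.all_reduce(local_sum_tensor, op=dist.ReduceOp.SUM, group=tp_or_pp_group)  # placeholder names
total_norm = float(local_sum) ** (1.0 / p)
print(total_norm)
```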
- if self.pp_size > 1: - for shared_param in self.shared_params: - if self.stage_manager.stage in shared_param: - stage_working_shared_param = shared_param[self.stage_manager.stage] - stage_master_shared_param = self.working_to_master_map[stage_working_shared_param] - if grad is stage_master_shared_param.grad: - grad_norm_exponentiated /= len(shared_param) - - total_norm_exponentiated += grad_norm_exponentiated - - total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - if self.tp_size > 1: - # compute norm in tp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) - if self.pp_size > 1: - # compute norm in pp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) - - # compute the total_norm - total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) - - return total_norm - - -class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): - def __init__( - self, - optimizer: Optimizer, - model: HybridParallelModule, - use_pipeline: bool, - param_info: OrderedDict, - pg_to_param_list: Dict[ProcessGroup, List[torch.nn.Parameter]] = None, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2.0, - backoff_factor: float = 0.5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm - tp_process_group: Optional[ProcessGroup] = None, # if using tp - pp_process_group: Optional[ProcessGroup] = None, # if using pp - forced_dtype: Optional[torch.dtype] = None, - overlap_allgather: bool = False, - ): - self.model = model - self.param_info = param_info - self.stage_manager = model.stage_manager - self.shared_params = model.shared_params - self.tp_pg = tp_process_group - self.pp_pg = pp_process_group - if use_pipeline: - reinitialize_optimizer(optimizer, model) - super().__init__( - optimizer=optimizer, - initial_scale=initial_scale, - min_scale=min_scale, - pg_to_param_list=pg_to_param_list, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - clip_grad_norm=clip_grad_norm, - verbose=verbose, - reduce_bucket_size=reduce_bucket_size, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - partition_grad=partition_grad, - cpu_offload=cpu_offload, - dp_process_group=dp_process_group, - forced_dtype=forced_dtype, - overlap_allgather=overlap_allgather, - ) - - def sync_dp_grads(self): - r""" - Synchronize gradients in the data parallelism dimension. - - This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients - in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions, - namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization - and readability. - - Args: - None - - Returns: - None - """ - # Call the superclass `_sync_grad` method to synchronize gradients. 
- super()._sync_grad() - - def _sync_sp_grads(self): - r""" - Synchronize gradients that are partially derived within sequence parallelism. - - This method is responsible for synchronizing partially derived gradients across tp parallelism groups. - It identifies gradients that ara partially derived or not and synchronizes them. - If synchronization is required and gradients are found to be synchronized, - it performs the synchronization. - - Args: - None - - Returns: - None - """ - - def _get_all_working_grads() -> List[Tensor]: - """Retrieve all working gradients from different parameter groups.""" - all_working_grads = [] - for group_id in range(self.num_param_groups): - working_grads = self.get_working_grads_by_group_id(group_id) - all_working_grads.extend(working_grads) - return all_working_grads - - def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]: - """Identify gradients to be synchronized in the sequence parallelism.""" - grads_to_sync = [] - for grad in all_working_grads: - param_id_for_grad = self.get_param_id_for_grad(grad) - param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value - if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad): - grads_to_sync.append(grad) - - if len(grads_to_sync) > 0: - return grads_to_sync - else: - return None - - # Get all working gradients and gradients to be synchronized. - all_working_grads = _get_all_working_grads() - grads_to_sync = _get_grads_to_sync(all_working_grads) - if self.require_grad_sync and grads_to_sync is not None: - # Synchronize sequence parallelism gradients if required. - SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync) - else: - return - - def backward(self, loss, retain_graph=False): - """ - Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients. - - This method performs the backward pass for gradient computation based on a given loss tensor. - If sequence parallelism is enabled and gradient synchronization is required, it will synchronize - gradients that are partially derived within sequence parallelism across TP parallelism groups. - - Args: - loss: The loss tensor to compute gradients with respect to. - retain_graph (bool): Whether to retain the computation graph. - - Returns: - None - """ - # Call the superclass backward method to compute gradients. - super().backward(loss, retain_graph) - - if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: - # If gradient synchronization is required, sync sequence parallelism gradients. - self._sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def backward_by_grad(self, tensor, grad): - """ - Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients. - - This method performs a backward pass for gradient computation based on a precomputed gradient tensor. - If sequence parallelism is enabled and gradient synchronization is required, it will synchronize - gradients that are partially derived within sequence parallelism across TP parallelism groups. - - Args: - tensor: The input tensor for which gradients are computed. - grad: The precomputed gradient tensor to compute gradients with respect to the input tensor. - - Returns: - None - """ - # Call the superclass backward_by_grad method to compute gradients. 
- super().backward_by_grad(tensor, grad) - - if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism: - # If gradient synchronization is required, sync sequence parallelism gradients. - self._sync_sp_grads() - else: - # If gradient synchronization is is not required, return. - return - - def _compute_grad_norm(self, dp_pg, gradients: List[Tensor], norm_type: int = 2) -> float: - r""" - Compute and return the gradient norm for gradient clipping. - - Args: - gradients (List[Tensor]): A list of tensors containing gradients. - norm_type (int, optional): Type of the p-norm to be computed. Defaults to 2. - - Returns: - float: The computed gradient norm. - """ - - # Check if the list of gradients is empty - if len(gradients) == 0: - return 0.0 - - dp_size = get_world_size(dp_pg) if dp_pg is not None else 1 - tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1 - pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1 - norm_type = float(norm_type) - - if norm_type == inf: - # The parent class calculates the norm of 'dp' gradients, - # so we only need to calculate the norm 'tp' of 'pp' gradients. - total_norm = super()._compute_grad_norm(gradients, norm_type) - - total_norm_cuda = torch.tensor( - [float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - - if tp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg) - if pp_size > 1: - dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg) - - total_norm = total_norm_cuda.item() - else: - total_norm_exponentiated = 0.0 - for grad in gradients: - grad_norm_exponentiated = grad.data.double().norm(norm_type) ** norm_type - - # If 'tp_size' is greater than 1 and the parameter for the gradient is not a distributed tensor, - # it indicates that the parameter is not distributed across devices of the 'tp_group'. - # Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'. - # However, we still perform the 'all_reduce' operation for the sake of good coding practices. - # To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.' - if tp_size > 1: - param_id_for_grad = self.get_param_id_for_grad(grad) - param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value - - if not is_distributed_tensor(param_for_grad): - grad_norm_exponentiated /= tp_size - - # If 'pp_size' is greater than 1 and the gradient belongs to shared parameters, - # it means that this parameter is used in two different pipeline stages. - # To avoid redundant norm calculations, we divide the exponent of this norm by - # the number of shared stages. 
- if pp_size > 1: - for shared_param in self.shared_params: - if self.stage_manager.stage in shared_param: - stage_shared_param = shared_param[self.stage_manager.stage] - working_grad = self.get_working_grad_by_param_id(id(stage_shared_param)) - if grad is working_grad: - grad_norm_exponentiated /= len(shared_param) - - total_norm_exponentiated += grad_norm_exponentiated - - total_norm_exponentiated_cuda = torch.tensor( - [float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32 - ) - if dp_size > 1: - # compute norm in dp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=dp_pg) - if tp_size > 1: - # compute norm in tp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg) - if pp_size > 1: - # compute norm in pp process group - dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg) - - # Compute the 'total_norm' from 'total_norm_exponentiated' - total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type) - - return total_norm - - -class HybridParallelPlugin(PipelinePluginBase): - """ - Plugin for Hybrid Parallel Training. - Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. - The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). - - ```python - from colossalai.booster import Booster - from colossalai.booster.plugin import HybridParallelPlugin - - model, train_dataset, optimizer, criterion = ... - plugin = HybridParallelPlugin(tp_size=2, pp_size=2) - - train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) - booster = Booster(plugin=plugin) - model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) - ``` - - Args: - tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. - pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. - sp_size (int): The size of sequence parallelism. - precision (str, optional): Specifies the precision of parameters during training. - Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. - Defaults to 'fp16'. - zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. - When set to 0, ZeRO will not be used. Defaults to 0. - enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. - Currently all the optimization methods include fused normalization, flash attention and JIT. - Defaults to False. - enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. - enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. - enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. - enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. - sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". 
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
-        parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Defaults to True.
-        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
-        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
-            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
-            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
-        initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
-        min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
-        growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
-        backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
-        growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
-        hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
-        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
-        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
-        broadcast_buffers (bool, optional): Whether to broadcast buffers at the beginning of training when using DDP. Defaults to True.
-        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
-        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
-        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
-        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
-        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
-        zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
-        cpu_offload (bool, optional): Whether to enable cpu_offload when using ZeRO. Defaults to False.
-        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
-        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
-        custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
-        pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
-        num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
-        gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
-        enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
-        make_vocab_size_divisible_by (int, optional): Used when padding the vocabulary size so that a faster kernel can be chosen. Defaults to 64.
-        overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. Defaults to True.
-        inner_ring_size (int, optional): The inner ring size of 2D Ring Attention when sp mode is "ring_attn".
-            It's advisable not to tune this (especially in single-node settings) and let it be set heuristically based on topology by default.
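For orientation, a minimal configuration sketch follows (the sizes and flags are illustrative assumptions, not taken from this patch); it respects the two constraints enforced later in `__init__`: pipeline parallelism requires either `num_microbatches` or `microbatch_size`, and ZeRO must stay at stage 0 or 1 once `pp_size > 1`.

```python
# Hypothetical sizes for a world size of 8 (tp 2 * pp 2 * dp 2); requires a prior
# colossalai.launch / torch.distributed init, otherwise the world-size checks fail.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    zero_stage=1,             # only stage 0 or 1 is allowed together with pipeline parallelism
    precision="bf16",
    num_microbatches=4,       # or microbatch_size=...; one of the two is mandatory when pp_size > 1
    enable_flash_attention=True,
    max_norm=1.0,
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)
```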
- - """ - - def __init__( - self, - tp_size: int, - pp_size: int, - sp_size: int = None, - precision: str = "fp16", - zero_stage: int = 0, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_jit_fused: bool = False, - enable_sequence_parallelism: bool = False, - sequence_parallelism_mode: str = None, - enable_sequence_overlap: bool = False, - parallel_output: bool = True, - num_microbatches: Optional[int] = None, - microbatch_size: Optional[int] = None, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - broadcast_buffers: bool = True, - ddp_bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False, - zero_bucket_size_in_m: int = 12, - cpu_offload: bool = False, - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - custom_policy: Policy = None, - pp_style: str = "1f1b", - num_model_chunks: int = 1, - num_layers_per_stage: Optional[List[int]] = None, - gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, - enable_metadata_cache: bool = True, - make_vocab_size_divisible_by: int = 64, - dp_outside: bool = True, - overlap_p2p: bool = True, - overlap_allgather: bool = False, - inner_ring_size: int = None, - ) -> None: - super().__init__() - self.logger = get_dist_logger() - - assert ( - dist.get_world_size() % (tp_size * pp_size) == 0 - ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" - - if enable_sequence_parallelism: - self.sequence_parallelism_mode = ( - sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" - ) - assert ( - self.sequence_parallelism_mode in SUPPORT_SP_MODE - ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" - if self.sequence_parallelism_mode in ["split_gather", "ring"]: - assert ( - tp_size > 1 - ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" - if sp_size != 1: - self.logger.warning( - f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." - ) - self.sp_size = 1 - self.dp_size = dist.get_world_size() // (tp_size * pp_size) - elif self.sequence_parallelism_mode in ["all_to_all", "ring_attn"]: - self.sp_size = 1 if sp_size is None else sp_size - self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) - if self.sequence_parallelism_mode == "ring_attn": - enable_flash_attention = True - else: - self.dp_size = dist.get_world_size() // (tp_size * pp_size) - assert ( - sp_size == 1 or sp_size is None - ), f"You should not set sp_size when sequence parallelism is not enabled." 
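To make the size bookkeeping in the branches above concrete, a small worked example (the numbers are illustrative): with `all_to_all` sequence parallelism the data-parallel size is `world_size // (sp_size * pp_size * tp_size)`, whereas `split_gather` and `ring` reuse the tensor-parallel ranks for sequence parallelism, so `sp_size` is forced to 1.

```python
# Illustrative arithmetic only, mirroring the divisibility checks performed in __init__.
world_size, tp_size, pp_size, sp_size = 16, 2, 2, 2  # assumed values

assert world_size % (tp_size * pp_size) == 0
dp_size_all_to_all = world_size // (sp_size * pp_size * tp_size)  # -> 2
dp_size_split_gather = world_size // (tp_size * pp_size)          # sp group == tp group, -> 4
print(dp_size_all_to_all, dp_size_split_gather)
```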
- self.sp_size = 1 - - self.tp_size = tp_size - self.pp_size = pp_size - self.precision = precision - self.zero_stage = zero_stage - self.cpu_offload = cpu_offload - self.enable_all_optimization = enable_all_optimization - self.enable_fused_normalization = enable_fused_normalization - self.enable_flash_attention = enable_flash_attention - self.enable_jit_fused = enable_jit_fused - self.enable_sequence_parallelism = enable_sequence_parallelism - if dp_outside: - self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 - if sequence_parallelism_mode == "ring_attn": - # Swap tp and sp since 2D Ring has better inter-node latency - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.sp_size, self.tp_size) - self.sp_axis = 2 - self.tp_axis = 3 - else: - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) - else: - self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 - if sequence_parallelism_mode == "ring_attn": - self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.sp_size, self.tp_size) - self.sp_axis = 2 - self.tp_axis = 3 - else: - self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size) - - self.stage_manager = None - self.schedule = None - self.custom_policy = custom_policy - assert zero_stage in (0, 1, 2) - if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" - assert ( - num_microbatches is not None or microbatch_size is not None - ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" - assert ( - self.zero_stage <= 1 - ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" - self.stage_manager = PipelineStageManager( - self.pg_mesh, - pipeline_axis=self.pp_axis, - enable_interleave=pp_style == "interleaved", - num_model_chunks=num_model_chunks, - num_layers_per_stage=num_layers_per_stage, - ) - - if pp_style == "interleaved": - assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( - stage_manager=self.stage_manager, - num_model_chunks=num_model_chunks, - num_microbatch=num_microbatches, - microbatch_size=microbatch_size, - enable_metadata_cache=enable_metadata_cache, - overlap_p2p=overlap_p2p, - ) - elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( - stage_manager=self.stage_manager, - num_microbatches=num_microbatches, - microbatch_size=microbatch_size, - enable_metadata_cache=enable_metadata_cache, - ) - else: - raise NotImplementedError() - if sequence_parallelism_mode == "ring_attn": - if not parallel_output: - self.logger.warning( - "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet." 
- ) - parallel_output = True - - self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) - self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis) - self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) - if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: - self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) - else: - self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) - - self.shard_config = ShardConfig( - tensor_parallel_process_group=self.tp_group, - sequence_parallel_process_group=self.sp_group, - pipeline_stage_manager=self.stage_manager, - enable_tensor_parallelism=self.tp_size > 1, - enable_all_optimization=self.enable_all_optimization, - enable_fused_normalization=self.enable_fused_normalization, - enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism, - sequence_parallelism_mode=sequence_parallelism_mode, - enable_sequence_overlap=enable_sequence_overlap, - parallel_output=parallel_output, - make_vocab_size_divisible_by=make_vocab_size_divisible_by, - gradient_checkpoint_config=gradient_checkpoint_config, - inner_ring_size=inner_ring_size, - ) - self.amp_config = dict( - initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - ) - - self.ddp_config = dict( - broadcast_buffers=broadcast_buffers, - bucket_cap_mb=ddp_bucket_cap_mb, - find_unused_parameters=find_unused_parameters, - check_reduction=check_reduction, - gradient_as_bucket_view=gradient_as_bucket_view, - static_graph=static_graph, - ) - - self.zero_config = dict( - reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - cpu_offload=cpu_offload, - partition_grad=(self.zero_stage == 2), - forced_dtype=PRECISION_TORCH_TYPE[precision], - overlap_allgather=overlap_allgather, - ) - - self.max_norm = max_norm - - def __del__(self): - """Destroy the process groups in ProcessGroupMesh""" - self.pg_mesh.destroy_mesh_process_groups() - - @property - def enable_pipeline_parallelism(self) -> bool: - return self.pp_size > 1 - - def supported_devices(self) -> List[str]: - return ["cuda", "npu"] - - def supported_precisions(self) -> List[str]: - return ["fp16", "bf16", "fp32"] - - def control_device(self) -> bool: - return True - - def control_precision(self) -> bool: - return True - - def support_no_sync(self) -> bool: - return True - - def support_lora(self) -> bool: - return True - - def control_checkpoint_io(self) -> bool: - return True - - def configure( - self, - model: Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - dataloader: Optional[DataLoader] = None, - lr_scheduler: Optional[LRScheduler] = None, - ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - param_info = get_param_info(optimizer) - - # TODO: Support Galore + ZeRO - zero_stage = self.zero_stage - zero_config = deepcopy(self.zero_config) - - # Replace with distributed implementation if exists - optimizer = cast_to_distributed(optimizer) - - if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: - self.logger.warning( - "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO." 
- ) - zero_config["partition_grad"] = False - zero_stage = 0 - - if not isinstance(model, ModelWrapper): - # Shouldn't use pp (frequent grad accumulation) with torch ddp - use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( - self.dp_size == 1 and self.pp_size == 1 - ) - - # Apply Hybrid ZeRO across DP * SP ranks - if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode): - dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis]) - self.dp_size = get_world_size(dp_group) - else: - dp_group = self.dp_group - model = HybridParallelModule( - model, - precision=self.precision, - shard_config=self.shard_config, - dp_group=dp_group, - tp_group=self.tp_group, - sp_group=self.sp_group, - use_ddp=use_ddp, - ddp_config=self.ddp_config, - custom_policy=self.custom_policy, - overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), - ) - if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): - if zero_stage == 0: - is_zero = False - if self.precision in ["fp16", "bf16"]: - optimizer = HybridParallelAMPOptimizer( - optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - precision=self.precision, - max_norm=self.max_norm, - pp_process_group=self.pp_group, - tp_process_group=self.tp_group, - **self.amp_config, - ) - else: - optimizer = HybridParallelNaiveOptimizer( - optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - max_norm=self.max_norm, - pp_process_group=self.pp_group, - tp_process_group=self.tp_group, - ) - else: - is_zero = self.dp_size > 1 - if self.dp_size == 1: - self.logger.warning( - "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you do not intend to use cpu_offload, please consider set zero_stage=0." - ) - - assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." - optimizer = HybridParallelZeroOptimizer( - optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - dp_process_group=dp_group, - tp_process_group=self.tp_group, - pp_process_group=self.pp_group, - verbose=True, - clip_grad_norm=self.max_norm, - **zero_config, - **self.amp_config, - ) - # inject update_master_params - model.update_master_params = MethodType(optimizer.update_master_params, model) - - # Setup optimizers that require global states - optim = optimizer.optim - if isinstance(optim, DistributedOptim): - shard_to_param = optimizer.get_master_to_working_map() if is_zero else {} - padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int) - optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero) - - return model, optimizer, criterion, dataloader, lr_scheduler - - def execute_pipeline( - self, - data_iter: Iterator, - model: HybridParallelModule, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Optional[ - Union[HybridParallelNaiveOptimizer, HybridParallelAMPOptimizer, HybridParallelZeroOptimizer] - ] = None, - return_loss: bool = True, - return_outputs: bool = False, - ) -> dict: - assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" - - if return_outputs: - self.logger.warning("return_outputs may lead to significant extra memory consumption.") - - # Create a context for gradient synchronization based on the optimizer type. 
- # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). - # This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once), - # so we disable it, performing manual reduction instead. - ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() - - with ctx, model._wait_all_gather(): - outputs = self.schedule.forward_backward_step( - model, data_iter, criterion, optimizer, return_loss, return_outputs - ) - - # run with gradients accumulation - if model.require_grad_sync == False or ( - isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False - ): - return outputs - - # Synchronize the grads of shared parameters of the model. - model.sync_shared_params() - # Synchronize sequence parallelism gradients of the model. - model.sync_sp_grads() - - # Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so. - # Otherwise, synchronize data parallelism gradients of the model. - # This is because these are two different forms of data parallelism. - if isinstance(optimizer, HybridParallelZeroOptimizer): - optimizer.sync_dp_grads() - else: - model.sync_dp_grads() - - return outputs - - def prepare_dataloader( - self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - distributed_sampler_cls=None, - **kwargs, - ): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns:` - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
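The synchronization choreography described above (running the schedule under `no_sync`, then manually syncing shared, sequence-parallel and data-parallel gradients) is what a caller leans on during a normal step. A hedged sketch of driving it through the booster, where `booster`, `model`, `criterion`, `optimizer` and `train_dataloader` are assumed to come from `booster.boost(...)` and `plugin.prepare_dataloader(...)`:

```python
# Sketch of one pipeline-parallel training step; all names are assumed to exist already.
model.train()
data_iter = iter(train_dataloader)
outputs = booster.execute_pipeline(data_iter, model, criterion, optimizer, return_loss=True)
if booster.plugin.stage_manager.is_last_stage():
    loss = outputs["loss"]  # the loss is only materialized on the last pipeline stage
optimizer.step()
optimizer.zero_grad()
```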
- """ - _kwargs = kwargs.copy() - distributed_sampler_cls = distributed_sampler_cls or DistributedSampler - sampler = distributed_sampler_cls( - dataset, - num_replicas=self.dp_group.size(), - rank=dist.get_group_rank(self.dp_group, global_rank=dist.get_rank()), - shuffle=shuffle, - ) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader( - dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs, - ) - - def get_checkpoint_io(self) -> CheckpointIO: - return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) - - def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: - assert ( - self.zero_stage != 2 - ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed." - return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync() - - def enable_lora( - self, - model: Module, - pretrained_dir: Optional[str] = None, - lora_config: Optional[Dict] = None, - bnb_quantization_config: Optional[BnbQuantizationConfig] = None, - ) -> Module: - from peft import PeftModel, get_peft_model - - assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." - assert self.pp_size == 1 and self.tp_size == 1 - self.lora_enabled = True - self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr") - - if bnb_quantization_config is not None: - model = quantize_model(model, bnb_quantization_config) - - if pretrained_dir is None: - peft_model = get_peft_model(model, lora_config) - else: - peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) - return peft_model diff --git a/colossalai/booster/plugin/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/plugin/low_level_zero_plugin.py deleted file mode 100644 index 6c36bad3c214..000000000000 --- a/colossalai/booster/plugin/plugin/low_level_zero_plugin.py +++ /dev/null @@ -1,521 +0,0 @@ -import enum -import os -from contextlib import nullcontext -from functools import partial -from pathlib import Path -from types import MethodType -from typing import Callable, Dict, Iterator, List, Optional, Tuple - -import torch -import torch.distributed -import torch.distributed as dist -import torch.nn as nn -from torch.distributed.distributed_c10d import _get_default_group -from torch.nn import Parameter -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils._pytree import tree_map -from torch.utils.data import DataLoader - -from colossalai.accelerator import get_accelerator -from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO -from colossalai.checkpoint_io.utils import ( - get_optimizer_base_filenames, - get_shard_filename, - load_param_groups_into_optimizer, - load_shard_state_dict, - load_states_into_optimizer, - save_param_groups, - save_state_dict, - sharded_optimizer_loading_epilogue, -) -from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper -from colossalai.interface.optimizer import DistributedOptim -from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed -from 
colossalai.quantization import BnbQuantizationConfig, quantize_model -from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.zero import LowLevelZeroOptimizer -from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle - -from .dp_plugin_base import DPPluginBase -from .torch_ddp_plugin import TorchDDPCheckpointIO - -__all__ = ["LowLevelZeroPlugin"] - - -def _convert_floating_point(x, dtype: torch.dtype = torch.float16): - if isinstance(x, torch.Tensor) and torch.is_floating_point(x): - return x.to(dtype) - return x - - -SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"] - - -class OptimizerParamCheckState(enum.Enum): - ORIGIN_PARAM_FINDED = 0 - ORIGIN_PARAM_NOT_FIND = -1 - LORA_PARM_EXISTED = -2 - - -class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: - super().__init__(module) - self.dtype = None - if precision == "fp16": - self.dtype = torch.float16 - elif precision == "bf16": - self.dtype = torch.bfloat16 - if self.dtype is not None: - module = module.to(self.dtype) - module = module.to(get_accelerator().get_current_device()) - self.module = module - self.convert_fn = None - if self.dtype is not None: - self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) - self.overlap_allgather = overlap_allgather - if overlap_allgather: - self.op_hook = ZeroOpHook() - for p in module.parameters(): - if p.requires_grad and type(p) is not ColoParameter: - p.__class__ = ColoParameter - p.__init__(p, requires_grad=True) - - def forward(self, *args, **kwargs): - if self.convert_fn is not None: - args = tree_map(self.convert_fn, args) - kwargs = tree_map(self.convert_fn, kwargs) - ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() - with ctx: - return super().forward(*args, **kwargs) - - def _force_wait_all_gather(self): - for p in self.module.parameters(): - wait_all_gather_handle(p) - - -class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): - """Save optimizer to checkpoint but only on master process. - - Args: - optimizer (OptimizerWrapper): Optimizer to save state_dict - checkpoint (str): Path to save checkpoint - gather_dtensor (bool): Whether to gather_dtensor, not used - """ - assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" - # the `state_dict` in LowLevelZeroOptimizer has communication - # if only the master rank collect state_dict and save, - # the communication on each rank would not match - state_dict = optimizer.state_dict() - if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors=False) - self.logger = get_dist_logger() - - def save_sharded_optimizer( - self, - optimizer: OptimizerWrapper, - checkpoint: str, - gather_dtensor: bool = False, - prefix: str = None, - size_per_shard: int = 1024, - ): - """ - Save sharded Zero-optimizer checkpoint under the given checkpointing path. 
- The following files will be created under the path: - - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names - - A group file (pytorch_optim_group.bin) recording information of param_groups - - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way - - Args: - optimizer (OptimizerWrapper): Optimizer to save sharded state_dict - checkpoint (str): Path to save optimizer state_dict - gather_dtensor (bool): Whether to gather_dtensor, not used - prefix (str): Perfix of file to save - size_per_shard (int): Max file size of each file that store state tensors - """ - assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" - if os.path.isfile(checkpoint): - self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") - return - - Path(checkpoint).mkdir(parents=True, exist_ok=True) - - # state_dict only provide only 'param_groups' - state_dict = optimizer.optim.state_dict() - # state shard would be handled by the low-level zero optimizer - sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard) - - # Preparing file paths and index file. - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) - index_file = CheckpointIndexFile(checkpoint) - index_file.append_meta_data("param_groups", param_group_file) - - # Store the information of param groups to param_group_file. - if self.coordinator.is_master(): - group_file_path = os.path.join(checkpoint, param_group_file) - save_param_groups(state_dict, group_file_path) - - # Save shards of optimizer states. - total_size = 0 - for idx, shard_pair in enumerate(sharded_state): - shard, current_size = shard_pair - shard_file = get_shard_filename(states_name, idx) - total_size = total_size + current_size - for param_id in shard.keys(): - index_file.append_weight_map(str(param_id), shard_file) - - checkpoint_file_path = os.path.join(checkpoint, shard_file) - if self.coordinator.is_master(): - save_state_dict(shard, checkpoint_file_path, use_safetensors=False) - - # Wrap up index file. - index_file.append_meta_data("total_size", total_size) - if self.coordinator.is_master(): - index_file.write_index_file(save_index_file) - self.logger.info( - f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) - - def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): - """Load sharded optimizer with the given path to index file. - - Args: - optimizer (OptimizerWrapper): Optimizer to load state_dict - index_file_path (str): Path to the index file - prefix (str): Not used. - """ - assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!" - optimizer = optimizer.unwrap() - - # Read checkpoint index file. - ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) - - # Load param_groups - param_group_path = ckpt_index_file.get_param_group_filename() - if param_group_path is None: - raise RuntimeError( - f"Invalid index file path {index_file_path} for an optimizer. \ - Lacking param group file under current directory." 
- ) - id_map = load_param_groups_into_optimizer(optimizer, param_group_path) - - checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() - - for shard_file in checkpoint_files: - state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) - # shard state dict - for param_idx, state in state_dict.items(): - for k, v in state.items(): - if isinstance(v, torch.Tensor) and k != "step": - padding_size = ( - self.coordinator.world_size - v.numel() % self.coordinator.world_size - ) % self.coordinator.world_size - with torch.no_grad(): - v = v.flatten() - if padding_size > 0: - v = torch.nn.functional.pad(v, [0, padding_size]) - v_list = v.split(v.numel() // self.coordinator.world_size) - state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone() - load_states_into_optimizer(optimizer, state_dict, id_map) - sharded_optimizer_loading_epilogue(optimizer) - - def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): - assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" - model._force_wait_all_gather() - super().load_unsharded_model(model, checkpoint, strict) - model.update_master_params() - - def load_sharded_model( - self, - model: ModelWrapper, - checkpoint_index_file: Path, - strict: bool = False, - use_safetensors: bool = False, - load_sub_module: bool = True, - ): - assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" - model._force_wait_all_gather() - super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module) - model.update_master_params() - - def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): - assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" - model._force_wait_all_gather() - return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) - - def save_sharded_model( - self, - model: ModelWrapper, - checkpoint_path: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False, - ): - assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!" - model._force_wait_all_gather() - return super().save_sharded_model( - model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors - ) - - def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): - if os.path.isfile(checkpoint): - self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") - return - from peft import PeftModel - - assert isinstance(model, ModelWrapper), "Please boost the model before saving!" - model._force_wait_all_gather() - peft_model = model.unwrap() - assert isinstance( - peft_model, PeftModel - ), "The model doesn't have lora adapters, please enable lora before saving." - return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors) - - -class LowLevelZeroPlugin(DPPluginBase): - """ - Plugin for low level zero. - - ```python - from colossalai.booster import Booster - from colossalai.booster.plugin import LowLevelZeroPlugin - - model, train_dataset, optimizer, criterion = ... 
-    plugin = LowLevelZeroPlugin()
-
-    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
-    booster = Booster(plugin=plugin)
-    model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
-    ```
-
-    Args:
-        stage (int, optional): ZeRO stage. Defaults to 1.
-        precision (str, optional): Precision. Supports 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
-        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
-        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
-        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
-        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
-        growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
-        hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
-        max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
-        max_norm (float, optional): max_norm used for `clip_grad_norm`. Note that you should not call
-            clip_grad_norm yourself when using ZeRO DDP; the ZeRO optimizer takes care of gradient clipping.
-        norm_type (float, optional): norm_type used for `clip_grad_norm`.
-        reduce_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements. Defaults to 12.
-        communication_dtype (torch.dtype, optional): Communication dtype. If not specified, the dtype of param will be used. Defaults to None.
-        overlap_communication (bool, optional): Whether to overlap communication and computation. Defaults to True.
-        cpu_offload (bool, optional): Whether to offload grad, master weight and optimizer state to cpu. Defaults to False.
-        verbose (bool, optional): Verbose mode. Debug info including grad overflow will be printed. Defaults to False.
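As a point of reference, a small configuration sketch (the values are illustrative assumptions): stage 2 additionally partitions gradients, but gradient accumulation via `no_sync` is then unavailable, because `support_no_sync()` below only returns True for stage 1.

```python
# Hypothetical settings; assumes distributed state was initialized via colossalai.launch.
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(
    stage=2,                   # partition gradients in addition to optimizer states
    precision="bf16",          # bf16 avoids fp16 loss-scaling concerns
    max_norm=1.0,              # clipping is handled inside the ZeRO optimizer
    overlap_communication=True,
)
booster = Booster(plugin=plugin)
```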
- """ - - def __init__( - self, - stage: int = 1, - precision: str = "fp16", - initial_scale: float = 2**32, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0.0, - norm_type: float = 2.0, - reduce_bucket_size_in_m: int = 12, - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True, - overlap_allgather: bool = False, - cpu_offload: bool = False, - master_weights: bool = True, - verbose: bool = False, - ) -> None: - super().__init__() - assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" - assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training" - assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now" - self.stage = stage - self.precision = precision - self.zero_optim_kwargs = dict( - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - clip_grad_norm=max_norm, - reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - partition_grad=(stage == 2), - cpu_offload=cpu_offload, - master_weights=master_weights, - overlap_allgather=overlap_allgather, - ) - self.lora_enabled = False - self.verbose = verbose - self.logger = get_dist_logger() - # set class name with stage, for better error message - setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") - - def support_no_sync(self) -> bool: - return self.stage == 1 - - def support_lora(self) -> bool: - return False - - def control_precision(self) -> bool: - return True - - def supported_precisions(self) -> List[str]: - return SUPPORTED_PRECISION - - def control_device(self) -> bool: - return True - - def supported_devices(self) -> List[str]: - return ["cuda", "npu"] - - def support_lora(self) -> bool: - return True - - def enable_lora( - self, - model: nn.Module, - pretrained_dir: Optional[str] = None, - lora_config: Optional[Dict] = None, - bnb_quantization_config: Optional[BnbQuantizationConfig] = None, - ) -> nn.Module: - from peft import PeftModel, get_peft_model - - assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model." - self.lora_enabled = True - self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr") - - if bnb_quantization_config is not None: - model = quantize_model(model, bnb_quantization_config) - - if pretrained_dir is None: - peft_model = get_peft_model(model, lora_config) - else: - peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) - return peft_model - - def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter): - origin_param_id = id(origin_param) - for group_id, param_group in enumerate(optimizer.param_groups): - for p in param_group["params"]: - if id(p) == origin_param_id: - return group_id - return -1 - - def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter): - origin_param_id = id(origin_param) - lora_param_id = id(lora_param) - target_group_id = None - for group_id, param_group in enumerate(optimizer.param_groups): - for p in param_group["params"]: - if id(p) == lora_param_id: - # check if the lora parameter exists. 
- return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED - if id(p) == origin_param_id: - target_group_id = group_id - if target_group_id is not None: - return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED - else: - return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND - - def add_lora_params_to_optimizer(self, model, optimizer): - """add lora parameters to optimizer""" - name2param = {} - for name, param in model.named_parameters(): - name2param[name] = param - - for name, param in name2param.items(): - if "lora_A" in name or "lora_B" in name: - origin_key = name.replace("lora_A.", "") - origin_key = origin_key.replace("lora_B.", "") - origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer") - origin_param = name2param[origin_key] - group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) - if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: - self.logger.warning( - f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." - ) - elif ( - check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED - and group_id is not None - and group_id >= 0 - ): - optimizer.param_groups[group_id]["params"].append(param) - - def configure( - self, - model: nn.Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - dataloader: Optional[DataLoader] = None, - lr_scheduler: Optional[LRScheduler] = None, - ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - if self.lora_enabled: - from peft import PeftModel - - assert isinstance( - model, PeftModel - ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True" - if optimizer is not None: - self.add_lora_params_to_optimizer(model, optimizer) - - if not isinstance(model, ModelWrapper): - model = LowLevelZeroModel( - model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] - ) - - # TODO: Support Galore + ZeRO - zero_stage = self.stage - zero_optim_kwargs = {**self.zero_optim_kwargs} - dp_size = dist.get_world_size() - - # Replace with the distributed implementation if exists - optimizer = cast_to_distributed(optimizer) - - if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0: - self.logger.warning( - "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO." 
- ) - zero_optim_kwargs["partition_grad"] = False - zero_stage = 0 - - if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): - optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer( - optimizer, **zero_optim_kwargs, verbose=self.verbose - ) - # inject update_master_params - model.update_master_params = MethodType(optimizer.update_master_params, model) - - # Setup optimizers that require global states - optim = optimizer.optim - is_zero = dp_size > 1 and zero_stage > 0 - dp_group = _get_default_group() # Use the whole world - if isinstance(optim, DistributedOptim): - shard_to_param = optimizer.get_master_to_working_map() - padding_map = optimizer.get_param_padding_map() - optim.setup_distributed(None, dp_group, shard_to_param, padding_map, is_zero) - - return model, optimizer, criterion, dataloader, lr_scheduler - - def control_checkpoint_io(self) -> bool: - return True - - def get_checkpoint_io(self) -> CheckpointIO: - return LowLevelZeroCheckpointIO() - - def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - assert isinstance(optimizer, LowLevelZeroOptimizer) - return optimizer.no_sync() diff --git a/colossalai/booster/plugin/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/plugin/moe_hybrid_parallel_plugin.py deleted file mode 100644 index 874028f09b86..000000000000 --- a/colossalai/booster/plugin/plugin/moe_hybrid_parallel_plugin.py +++ /dev/null @@ -1,490 +0,0 @@ -from collections import defaultdict -from types import MethodType -from typing import Callable, List, Optional, OrderedDict, Tuple - -import torch -import torch.distributed as dist -from torch.distributed import ProcessGroup -from torch.nn import Module -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import DataLoader - -from colossalai.booster.plugin.hybrid_parallel_plugin import ( - PRECISION_TORCH_TYPE, - SUPPORT_SP_MODE, - HybridParallelAMPOptimizer, - HybridParallelModule, - HybridParallelNaiveOptimizer, - HybridParallelPlugin, - HybridParallelZeroOptimizer, - get_param_info, - reinitialize_optimizer, -) -from colossalai.checkpoint_io import MoECheckpointIO -from colossalai.cluster.process_group_mesh import ProcessGroupMesh -from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.interface.optimizer import DistributedOptim -from colossalai.logging import get_dist_logger -from colossalai.nn.optimizer import cast_to_distributed -from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule -from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule -from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.policies.base_policy import Policy -from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig -from colossalai.shardformer.shard.shard_config import ShardConfig -from colossalai.tensor.moe_tensor.api import is_moe_tensor - - -class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer): - def __init__( - self, - optimizer: Optimizer, - model: Module, - use_pipeline: bool, - dp_process_group: Optional[ProcessGroup], # the dp pg for comm - tp_process_group: Optional[ProcessGroup], # if using tp - pp_process_group: Optional[ProcessGroup], # if using pp - moe_dp_group: ProcessGroup, # moe dp pg for comm - param_info: OrderedDict, - initial_scale: int = 2**16, # grad scaler config - min_scale: int = 1, - growth_factor: float = 2.0, - backoff_factor: 
float = 0.5, - growth_interval: int = 2000, - hysteresis: int = 2, - max_scale: int = 2**24, - clip_grad_norm: float = 0.0, # grad clipping - verbose: bool = False, - reduce_bucket_size: int = 1024 * 1024, # communication - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = False, - partition_grad: bool = False, # stage 2 flag - cpu_offload: bool = False, # cpu offload - forced_dtype: Optional[torch.dtype] = None, - overlap_allgather: bool = False, - ): - pg_param_list = { - dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())), - moe_dp_group: list(filter(is_moe_tensor, model.parameters())), - } - - if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0: - raise ValueError("No parameters found in dp_process_group or moe_dp_group") - - super().__init__( - model=model, - optimizer=optimizer, - use_pipeline=use_pipeline, - param_info=param_info, - initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - clip_grad_norm=clip_grad_norm, - verbose=verbose, - reduce_bucket_size=reduce_bucket_size, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - partition_grad=partition_grad, - cpu_offload=cpu_offload, - tp_process_group=tp_process_group, - pp_process_group=pp_process_group, - forced_dtype=forced_dtype, - pg_to_param_list=pg_param_list, - overlap_allgather=overlap_allgather, - ) - - -class MoeHybridParallelPlugin(HybridParallelPlugin): - """ - Plugin for MoE Hybrid Parallel Training, which is similar to HybridParallelPlugin - Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. - The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). - - ```python - from colossalai.booster import Booster - from colossalai.booster.plugin import MoeHybridParallelPlugin - - model, train_dataset, optimizer, criterion = ... - plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2, ep_size=2) - - train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) - booster = Booster(plugin=plugin) - model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) - ``` - - Args: - tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. - pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. - ep_size (int): The size of expert parallelism - sp_size (int): The size of sequence parallelism. - precision (str, optional): Specifies the precision of parameters during training. - Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. - Defaults to 'fp16'. - zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. - When set to 0, ZeRO will not be used. Defaults to 0. - enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. - Currently all the optimization methods include fused normalization, flash attention and JIT. - Defaults to False. - enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. 
- enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. - enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. - enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. - sequence_parallelism_mode (str): The Sequence parallelism mode. Can only be choosed from ["split_gather", "ring", "all_to_all"]. Defaults to "split_gather". - enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. - parallel_output (bool): Whether to keep the output parallel when enabling tensor parallelism. Default to True. - num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. - microbatch_size (int, optional): Microbatch size when using pipeline parallelism. - Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. - If ``num_microbatches`` is provided, this will be ignored. Defaults to None. - initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. - min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. - growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. - backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. - growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. - hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. - max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. - max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. - broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True. - ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25. - find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False. - check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False. - gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False. - static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False. - zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. - cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. - communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. - overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. - custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. - pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. - num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. - gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. - enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. 
- make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. - overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism - """ - - def __init__( - self, - tp_size: int, - pp_size: int, - ep_size: int, - sp_size: int = None, - precision: str = "fp16", - zero_stage: int = 0, - enable_all_optimization: bool = False, - enable_fused_normalization: bool = False, - enable_flash_attention: bool = False, - enable_jit_fused: bool = False, - enable_sequence_parallelism: bool = False, - sequence_parallelism_mode: str = None, - enable_sequence_overlap: bool = False, - parallel_output: bool = True, - num_microbatches: Optional[int] = None, - microbatch_size: Optional[int] = None, - initial_scale: float = 2**16, - min_scale: float = 1, - growth_factor: float = 2, - backoff_factor: float = 0.5, - growth_interval: int = 1000, - hysteresis: int = 2, - max_scale: float = 2**32, - max_norm: float = 0, - broadcast_buffers: bool = True, - ddp_bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False, - zero_bucket_size_in_m: int = 12, - cpu_offload: bool = False, - communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = False, - custom_policy: Policy = None, - pp_style: str = "1f1b", - num_model_chunks: int = 1, - num_layers_per_stage: Optional[List[int]] = None, - gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, - enable_metadata_cache: bool = True, - make_vocab_size_divisible_by: int = 64, - moe_dp_outside: bool = True, - overlap_p2p: bool = True, - overlap_allgather: bool = False, - ) -> None: - self.logger = get_dist_logger() - if overlap_communication or zero_stage == 2: - overlap_communication = False - zero_stage = 1 - self.logger.warning( - f"overlap_communication and zero_stage are set to False and 1 because " - f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. " - ) - - assert ( - dist.get_world_size() % (tp_size * pp_size) == 0 - ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" - if enable_sequence_parallelism: - self.sequence_parallelism_mode = ( - sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all" - ) - assert ( - self.sequence_parallelism_mode in SUPPORT_SP_MODE - ), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}" - if self.sequence_parallelism_mode in ["split_gather", "ring"]: - assert ( - tp_size > 1 - ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" - if sp_size != 1: - self.logger.warning( - f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." - ) - self.sp_size = 1 - self.dp_size = dist.get_world_size() // (tp_size * pp_size) - elif self.sequence_parallelism_mode in ["all_to_all"]: - self.sp_size = 1 if sp_size is None else sp_size - self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size) - else: - self.dp_size = dist.get_world_size() // (tp_size * pp_size) - assert ( - sp_size == 1 or sp_size is None - ), f"You should not set sp_size when sequence parallelism is not enabled." 
- self.sp_size = 1 - - assert self.dp_size % ep_size == 0, f"dp_size should be divisible by ep_size, {self.dp_size=} {ep_size=}" - self.moe_dp_size = self.dp_size // ep_size - self.ep_size = ep_size - self.tp_size = tp_size - self.pp_size = pp_size - self.precision = precision - self.zero_stage = zero_stage - self.cpu_offload = cpu_offload - self.enable_all_optimization = enable_all_optimization - self.enable_fused_normalization = enable_fused_normalization - self.enable_flash_attention = enable_flash_attention - self.enable_jit_fused = enable_jit_fused - self.enable_sequence_parallelism = enable_sequence_parallelism - if moe_dp_outside: - self.moe_dp_axis, self.pp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4 - self.pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.pp_size, self.ep_size, self.tp_size, self.sp_size) - else: - self.pp_axis, self.moe_dp_axis, self.ep_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3, 4 - self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size) - - self.stage_manager = None - self.schedule = None - self.custom_policy = custom_policy - assert zero_stage in (0, 1, 2) - if self.pp_size > 1: - assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style" - assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b" - assert ( - num_microbatches is not None or microbatch_size is not None - ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism" - assert ( - self.zero_stage <= 1 - ), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism" - self.stage_manager = PipelineStageManager( - self.pg_mesh, - pipeline_axis=self.pp_axis, - enable_interleave=pp_style == "interleaved", - num_model_chunks=num_model_chunks, - num_layers_per_stage=num_layers_per_stage, - ) - - if pp_style == "interleaved": - assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved" - self.schedule = InterleavedSchedule( - stage_manager=self.stage_manager, - num_model_chunks=num_model_chunks, - num_microbatch=num_microbatches, - microbatch_size=microbatch_size, - enable_metadata_cache=enable_metadata_cache, - overlap_p2p=overlap_p2p, - ) - elif pp_style == "1f1b": - self.schedule = OneForwardOneBackwardSchedule( - stage_manager=self.stage_manager, - num_microbatches=num_microbatches, - microbatch_size=microbatch_size, - enable_metadata_cache=enable_metadata_cache, - ) - else: - raise NotImplementedError() - - self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) - self.dp_group = self.pg_mesh.get_group_along_axis([self.moe_dp_axis, self.ep_axis]) - self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis) - self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.moe_dp_axis) - self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis) - if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]: - self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis) - else: - self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis) - - self.shard_config = ShardConfig( - tensor_parallel_process_group=self.tp_group, - sequence_parallel_process_group=self.sp_group, - ep_group=self.ep_group, - moe_dp_group=self.moe_dp_group, - pipeline_stage_manager=self.stage_manager, - enable_tensor_parallelism=self.tp_size > 1, - enable_all_optimization=self.enable_all_optimization, - 
enable_fused_normalization=self.enable_fused_normalization, - enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=enable_sequence_parallelism, - sequence_parallelism_mode=sequence_parallelism_mode, - enable_sequence_overlap=enable_sequence_overlap, - parallel_output=parallel_output, - make_vocab_size_divisible_by=make_vocab_size_divisible_by, - gradient_checkpoint_config=gradient_checkpoint_config, - ) - self.amp_config = dict( - initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - ) - - self.ddp_config = dict( - broadcast_buffers=broadcast_buffers, - bucket_cap_mb=ddp_bucket_cap_mb, - find_unused_parameters=find_unused_parameters, - check_reduction=check_reduction, - gradient_as_bucket_view=gradient_as_bucket_view, - static_graph=static_graph, - ) - - self.zero_config = dict( - reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - cpu_offload=cpu_offload, - partition_grad=(self.zero_stage == 2), - forced_dtype=PRECISION_TORCH_TYPE[precision], - overlap_allgather=overlap_allgather, - ) - - self.max_norm = max_norm - - def get_checkpoint_io(self) -> MoECheckpointIO: - return MoECheckpointIO( - self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage - ) - - def configure( - self, - model: Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - dataloader: Optional[DataLoader] = None, - lr_scheduler: Optional[LRScheduler] = None, - ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - param_info = get_param_info(optimizer) - - # TODO: Support Galore + ZeRO - # Replace with distributed implementation if exists - optimizer = cast_to_distributed(optimizer) - - if not isinstance(model, ModelWrapper): - use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or ( - self.dp_size == 1 - and self.pp_size == 1 - and self.enable_sequence_parallelism - and self.sequence_parallelism_mode == "all_to_all" - ) - if use_ddp: - self.logger.warning( - f"Will have to check all params are used in pytorch DDP since not all experts are always activated" - ) - self.ddp_config["find_unused_parameters"] = True - - if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group): - raise ValueError( - f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. 
set ep_size = 1) or set zero_stage > 0" - ) - - # sync gradients across DP * SP ranks - if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all": - dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis]) - else: - dp_group = self.dp_group - - model = HybridParallelModule( - module=model, - precision=self.precision, - shard_config=self.shard_config, - dp_group=dp_group, - tp_group=self.tp_group, - sp_group=self.sp_group, - use_ddp=use_ddp, - ddp_config=self.ddp_config, - custom_policy=self.custom_policy, - ) - if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): - if self.ep_size > 1: - # if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups - # but the optimizer is not aware of ep, so we need to update the optimizer - reinitialize_optimizer(optimizer, model) - - if self.zero_stage == 0: - is_zero = False - if self.precision in ["fp16", "bf16"]: - optimizer = HybridParallelAMPOptimizer( - optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - precision=self.precision, - max_norm=self.max_norm, - **self.amp_config, - ) - else: - optimizer = HybridParallelNaiveOptimizer( - optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - max_norm=self.max_norm, - pp_process_group=self.pp_group, - tp_process_group=self.tp_group, - ) - else: - if self.dp_size <= 1: - self.logger.warning( - "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you do not intend to use cpu_offload, please consider set zero_stage=0." - ) - assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." - optimizer = MoeHybridParallelZeroOptimizer( - optimizer, - model, - use_pipeline=self.enable_pipeline_parallelism, - param_info=param_info, - dp_process_group=dp_group, - tp_process_group=self.tp_group, - pp_process_group=self.pp_group, - moe_dp_group=self.moe_dp_group, - verbose=True, - clip_grad_norm=self.max_norm, - **self.zero_config, - **self.amp_config, - ) - # inject update_master_params - model.update_master_params = MethodType(optimizer.update_master_params, model) - - # Setup optimizers that require global states - optim = optimizer.optim - if isinstance(optim, DistributedOptim): - shard_to_param = optimizer.get_master_to_working_map() if is_zero else {} - padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int) - optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero) - - return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/booster/plugin/plugin/plugin_base.py b/colossalai/booster/plugin/plugin/plugin_base.py deleted file mode 100644 index 6dc0c560d06d..000000000000 --- a/colossalai/booster/plugin/plugin/plugin_base.py +++ /dev/null @@ -1,90 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Callable, Dict, Iterator, List, Optional, Tuple - -import torch.nn as nn -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import DataLoader, Dataset - -from colossalai.checkpoint_io import CheckpointIO -from colossalai.interface import OptimizerWrapper - -__all__ = ["Plugin"] - - -class Plugin(ABC): - @abstractmethod - def supported_devices(self) -> List[str]: - pass - - @abstractmethod - def supported_precisions(self) -> List[str]: - pass - - @abstractmethod - def 
control_precision(self) -> bool: - pass - - @abstractmethod - def control_device(self) -> bool: - pass - - @abstractmethod - def support_no_sync(self) -> bool: - pass - - @abstractmethod - def support_lora(self) -> bool: - pass - - @abstractmethod - def configure( - self, - model: nn.Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - dataloader: Optional[DataLoader] = None, - lr_scheduler: Optional[LRScheduler] = None, - ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - # implement this method - pass - - @abstractmethod - def control_checkpoint_io(self) -> bool: - """ - Whether the plugin controls the checkpoint io - """ - - @abstractmethod - def get_checkpoint_io(self) -> CheckpointIO: - """ - Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True. - """ - - @abstractmethod - def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - """ - Context manager to disable gradient synchronization. - """ - - @abstractmethod - def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module: - """ - Add LoRA modules to the model passed in. Should only be called in booster.enable_lora(). - """ - - @abstractmethod - def prepare_dataloader( - self, - dataset: Dataset, - batch_size: int, - shuffle: bool = False, - seed: int = 1024, - drop_last: bool = False, - pin_memory: bool = False, - num_workers: int = 0, - **kwargs, - ): - """Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` - """ diff --git a/colossalai/booster/plugin/plugin/pp_plugin_base.py b/colossalai/booster/plugin/plugin/pp_plugin_base.py deleted file mode 100644 index 3d91eb95b409..000000000000 --- a/colossalai/booster/plugin/plugin/pp_plugin_base.py +++ /dev/null @@ -1,22 +0,0 @@ -from abc import abstractmethod -from typing import Any, Callable, Iterator, Optional - -import torch - -from colossalai.interface import ModelWrapper, OptimizerWrapper - -from .plugin_base import Plugin - - -class PipelinePluginBase(Plugin): - @abstractmethod - def execute_pipeline( - self, - data_iter: Iterator, - model: ModelWrapper, - criterion: Callable[[Any, Any], torch.Tensor], - optimizer: Optional[OptimizerWrapper] = None, - return_loss: bool = True, - return_outputs: bool = False, - ) -> dict: - pass diff --git a/colossalai/booster/plugin/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/plugin/torch_ddp_plugin.py deleted file mode 100644 index 5116446a4295..000000000000 --- a/colossalai/booster/plugin/plugin/torch_ddp_plugin.py +++ /dev/null @@ -1,257 +0,0 @@ -from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union - -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import DataLoader - -from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO -from colossalai.cluster import DistCoordinator -from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.quantization import BnbQuantizationConfig, quantize_model -from colossalai.utils import get_current_device - -from .dp_plugin_base import DPPluginBase - -__all__ = ["TorchDDPPlugin"] - - -class TorchDDPCheckpointIO(GeneralCheckpointIO): - def __init__(self) -> None: - super().__init__() - self.coordinator = DistCoordinator() - - def load_unsharded_model(self, model: 
ModelWrapper, checkpoint: str, strict: bool = True): - """ - Load model from checkpoint. - """ - assert isinstance(model, ModelWrapper), "Please boost the model before loading!" - super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict) - - def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): - """ - Save model to checkpoint but only on master process. - """ - assert isinstance(model, ModelWrapper), "Please boost the model before saving!" - if self.coordinator.is_master(): - super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors) - - def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): - """ - Load optimizer from checkpoint. - """ - assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" - super().load_unsharded_optimizer(optimizer, checkpoint) - - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): - """ - Save optimizer to checkpoint but only on master process. - """ - assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" - if self.coordinator.is_master(): - super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) - - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): - """ - Save model to checkpoint but only on master process. - """ - if self.coordinator.is_master(): - super().save_lr_scheduler(lr_scheduler, checkpoint) - - def save_sharded_model( - self, - model: ModelWrapper, - checkpoint_path: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - max_shard_size: int = 1024, - use_safetensors: bool = False, - ): - """ - Save model to checkpoint but only on master process. - """ - assert isinstance(model, ModelWrapper), "Please boost the model before saving!" - if self.coordinator.is_master(): - super().save_sharded_model( - model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors - ) - - def load_sharded_model( - self, - model: ModelWrapper, - checkpoint_index_file: str, - strict: bool = False, - use_safetensors: bool = False, - load_sub_module: bool = True, - ): - """ - Load model from sharded checkpoint. - """ - assert isinstance(model, ModelWrapper), "Please boost the model before loading!" - super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module) - - def save_sharded_optimizer( - self, - optimizer: OptimizerWrapper, - checkpoint: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - ): - """ - Save optimizer to sharded checkpoint but only on master process. - """ - assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" - if self.coordinator.is_master(): - super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard) - - def load_sharded_optimizer( - self, - optimizer: Optimizer, - index_file_path: str, - prefix: Optional[str] = None, - ): - """ - Load optimizer from sharded checkpoint. - """ - assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" 
- super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix) - - def save_lora_as_pretrained( - self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False - ) -> None: - """ - Save the lora adapters and adapter configuration file to checkpoint directory. - """ - from peft import PeftModel - - assert isinstance(model, ModelWrapper), "Please boost the model before saving!" - if self.coordinator.is_master(): - peft_model = model.unwrap() - assert isinstance( - peft_model, PeftModel - ), "The model doesn't have lora adapters, please enable lora before saving." - peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors) - - -class TorchDDPModel(ModelWrapper): - def __init__(self, module: nn.Module, *args, **kwargs) -> None: - super().__init__(module) - self.module = DDP(module, *args, **kwargs) - - def unwrap(self): - return self.module.module - - -class TorchDDPPlugin(DPPluginBase): - """ - Plugin for PyTorch DDP. - - ```python - from colossalai.booster import Booster - from colossalai.booster.plugin import TorchDDPPlugin - - model, train_dataset, optimizer, criterion = ... - plugin = TorchDDPPlugin() - - train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) - booster = Booster(plugin=plugin) - model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) - ``` - - Args: - broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True. - bucket_cap_mb (int, optional): The bucket size in MB. Defaults to 25. - find_unused_parameters (bool, optional): Whether to find unused parameters. Defaults to False. - check_reduction (bool, optional): Whether to check reduction. Defaults to False. - gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False. - static_graph (bool, optional): Whether to use static graph. Defaults to False. 
- """ - - def __init__( - self, - broadcast_buffers: bool = True, - bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False, - ) -> None: - super().__init__() - self.ddp_kwargs = dict( - broadcast_buffers=broadcast_buffers, - bucket_cap_mb=bucket_cap_mb, - find_unused_parameters=find_unused_parameters, - check_reduction=check_reduction, - gradient_as_bucket_view=gradient_as_bucket_view, - static_graph=static_graph, - ) - - def support_no_sync(self) -> bool: - return True - - def support_lora(self) -> bool: - return True - - def control_precision(self) -> bool: - return False - - def supported_precisions(self) -> List[str]: - return ["fp16", "fp16_apex", "bf16", "fp8"] - - def control_device(self) -> bool: - return True - - def supported_devices(self) -> List[str]: - return ["cuda", "npu"] - - def configure( - self, - model: nn.Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - dataloader: Optional[DataLoader] = None, - lr_scheduler: Optional[LRScheduler] = None, - ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - # cast model to cuda - model = model.to(get_current_device()) - - # convert model to sync bn - model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) - - # wrap the model with PyTorch DDP - model = TorchDDPModel(model, **self.ddp_kwargs) - - if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): - optimizer = OptimizerWrapper(optimizer) - - return model, optimizer, criterion, dataloader, lr_scheduler - - def control_checkpoint_io(self) -> bool: - return True - - def get_checkpoint_io(self) -> CheckpointIO: - return TorchDDPCheckpointIO() - - def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin." - return model.module.no_sync() - - def enable_lora( - self, - model: nn.Module, - pretrained_dir: Optional[str] = None, - lora_config: Optional[Dict] = None, - bnb_quantization_config: Optional[BnbQuantizationConfig] = None, - ) -> nn.Module: - from peft import PeftModel, get_peft_model - - if bnb_quantization_config is not None: - model = quantize_model(model, bnb_quantization_config) - - assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model." 
- if pretrained_dir is None: - return get_peft_model(model, lora_config) - else: - return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True) diff --git a/colossalai/booster/plugin/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/plugin/torch_fsdp_plugin.py deleted file mode 100644 index 7b67da032d66..000000000000 --- a/colossalai/booster/plugin/plugin/torch_fsdp_plugin.py +++ /dev/null @@ -1,372 +0,0 @@ -import os -from pathlib import Path -from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple - -import torch -import torch.nn as nn -from packaging import version -from torch.distributed import ProcessGroup - -if version.parse(torch.__version__) >= version.parse("1.12.0"): - from torch.distributed.fsdp import FullStateDictConfig - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp import StateDictType - from torch.distributed.fsdp.fully_sharded_data_parallel import ( - BackwardPrefetch, - CPUOffload, - FullStateDictConfig, - MixedPrecision, - ShardingStrategy, - ) -else: - raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") - -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import DataLoader - -from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils -from colossalai.cluster import DistCoordinator -from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.logging import get_dist_logger - -from .dp_plugin_base import DPPluginBase - -__all__ = ["TorchFSDPPlugin"] - - -class TorchFSDPCheckpointIO(GeneralCheckpointIO): - def __init__(self) -> None: - super().__init__() - self.coordinator = DistCoordinator() - self.logger = get_dist_logger() - - def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool): - assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" - model = model.unwrap() - checkpoint = utils.load_state_dict(checkpoint) - model.load_state_dict(checkpoint) - - def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path): - assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!" - checkpoint = utils.load_state_dict(checkpoint) - fsdp_model = optimizer.unwrap_model() - sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model) - optimizer.load_state_dict(sharded_osd) - - def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): - """ - Save model to checkpoint but only on master process. - """ - assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" - model = model.unwrap() - cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): - full_model_state = model.state_dict() - utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) - - def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool): - """ - Save optimizer to checkpoint but only on master process. - """ - assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" 
- fsdp_model = optimizer.unwrap_model() - full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) - utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) - - def save_sharded_model( - self, - model: ModelWrapper, - checkpoint_path: str, - gather_dtensor: bool = True, - prefix: Optional[str] = None, - size_per_shard: int = 1024, - use_safetensors: bool = False, - ): - """ - Save model to checkpoint but only on master process. - """ - assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!" - if os.path.isfile(checkpoint_path): - self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") - return - - Path(checkpoint_path).mkdir(parents=True, exist_ok=True) - with FSDP.state_dict_type( - model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - ): - state_dict = model.unwrap().state_dict() - - state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard) - - weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors) - index_file = CheckpointIndexFile(checkpoint_path) - - # In general cases, is_master is set to True to get the right behavior. - total_size = utils.save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=self.coordinator.is_master(), - use_safetensors=use_safetensors, - ) - - # only save the index file on the master rank - if self.coordinator.is_master(): - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - utils.save_config_file(model.unwrap(), checkpoint_path) - self.logger.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) - - def load_sharded_model( - self, - model: nn.Module, - checkpoint_index_file: Path, - strict: bool = False, - use_safetensors: bool = False, - load_sub_module: bool = True, - ): - """ - Load model to checkpoint but only on master process. - """ - assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!" - use_safetensors = False - if "safetensors" in checkpoint_index_file.name: - use_safetensors = True - - if use_safetensors and not utils.is_safetensors_available(): - raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") - - # read checkpoint index file - ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) - checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() - - fsdp_state_dict = {} - for shard_file in checkpoint_files: - fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors)) - - with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT): - model.unwrap().load_state_dict(fsdp_state_dict, strict=False) - - def save_sharded_optimizer( - self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int - ): - """ - Save optimizer to checkpoint but only on master process. - """ - assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" 
- - if os.path.isfile(checkpoint): - self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") - return - - Path(checkpoint).mkdir(parents=True, exist_ok=True) - - with FSDP.state_dict_type( - optimizer.unwrap_model().unwrap(), - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - ): - fsdp_optim_state = FSDP.full_optim_state_dict( - optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True - ) - - if self.coordinator.is_master(): - # Preparing file paths and index file. - states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix) - index_file = CheckpointIndexFile(checkpoint) - - index_file.append_meta_data("param_groups", param_group_file) - group_file_path = os.path.join(checkpoint, param_group_file) - utils.save_param_groups(fsdp_optim_state, group_file_path) - - sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard) - - # Save shards of optimizer states. - # In general cases, is_master is set to True to get the right behavior. - total_size = utils.save_state_dict_shards( - sharded_state_dict=sharded_state, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=self.coordinator.is_master(), - use_safetensors=False, - ) - - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - self.logger.info( - f"The optimizer is going to be split to checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) - - def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int): - """ - Load optimizer to checkpoint but only on master process. - """ - assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" - - ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) - - # Load param_groups - param_group_path = ckpt_index_file.get_param_group_filename() - if param_group_path is None: - raise RuntimeError( - f"Invalid index file path {index_file_path} for an optimizer. " - "Looking param group file under current directory." - ) - - saved_param_groups = torch.load(param_group_path) - - # Load param - fsdp_optim_state = {} - checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() - for shard_file in checkpoint_files: - state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) - fsdp_optim_state.update(state_dict_shard) - - fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups) - - with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT): - fsdp_state = FSDP.optim_state_dict_to_load( - model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict - ) - optimizer.load_state_dict(fsdp_state) - - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): - """ - Save model to checkpoint but only on master process. 
- """ - if self.coordinator.is_master(): - super().save_lr_scheduler(lr_scheduler, checkpoint) - - -class TorchFSDPModel(ModelWrapper): - def __init__(self, module: nn.Module, *args, **kwargs) -> None: - super().__init__(module) - self.module = FSDP(module, *args, **kwargs) - - def unwrap(self): - return self.module - - -class FSDPOptimizerWrapper(OptimizerWrapper): - def __init__(self, optimizer: Optimizer, model: nn.Module): - self.model = model - super().__init__(optimizer) - - def unwrap_model(self) -> nn.Module: - return self.model - - -class TorchFSDPPlugin(DPPluginBase): - """ - Plugin for PyTorch FSDP. - - ```python - from colossalai.booster import Booster - from colossalai.booster.plugin import TorchFSDPPlugin - - model, train_dataset, optimizer, criterion = ... - plugin = TorchFSDPPlugin() - - train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) - booster = Booster(plugin=plugin) - model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) - ``` - - Args: - See https://pytorch.org/docs/stable/fsdp.html for details. - """ - - if version.parse(torch.__version__) >= version.parse("1.12.0"): - - def __init__( - self, - process_group: Optional[ProcessGroup] = None, - sharding_strategy: Optional[ShardingStrategy] = None, - cpu_offload: Optional[CPUOffload] = None, - auto_wrap_policy: Optional[Callable] = None, - backward_prefetch: Optional[BackwardPrefetch] = None, - mixed_precision: Optional[MixedPrecision] = None, - ignored_modules: Optional[Iterable[torch.nn.Module]] = None, - param_init_fn: Optional[Callable[[nn.Module], None]] = None, - sync_module_states: bool = False, - ): - super().__init__() - self.fsdp_kwargs = dict( - process_group=process_group, - sharding_strategy=sharding_strategy, - cpu_offload=cpu_offload, - auto_wrap_policy=auto_wrap_policy, - backward_prefetch=backward_prefetch, - mixed_precision=mixed_precision, - ignored_modules=ignored_modules, - param_init_fn=param_init_fn, - sync_module_states=sync_module_states, - ) - self.logger = get_dist_logger() - - else: - raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") - - def support_no_sync(self) -> bool: - return False - - def support_lora(self) -> bool: - return False - - def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - raise NotImplementedError("Torch fsdp no_sync func not supported yet.") - - def control_precision(self) -> bool: - return True - - def supported_precisions(self) -> List[str]: - return ["fp16", "bf16"] - - def control_device(self) -> bool: - return True - - def supported_devices(self) -> List[str]: - return ["cuda"] - - def configure( - self, - model: nn.Module, - optimizer: Optional[Optimizer] = None, - criterion: Optional[Callable] = None, - dataloader: Optional[DataLoader] = None, - lr_scheduler: Optional[LRScheduler] = None, - ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - # wrap the model with PyTorch FSDP - fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs) - - if optimizer is not None: - if len(optimizer.param_groups) > 1: - self.logger.warning( - "TorchFSDPPlugin does not support optimizer that use multi param groups. The results may not be as expected if used." 
- ) - optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults) - - if not isinstance(optimizer, FSDPOptimizerWrapper): - optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model) - - return fsdp_model, optimizer, criterion, dataloader, lr_scheduler - - def control_checkpoint_io(self) -> bool: - return True - - def get_checkpoint_io(self) -> CheckpointIO: - return TorchFSDPCheckpointIO() - - def enable_lora( - self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None - ) -> nn.Module: - raise NotImplementedError From 5444b75c15cd91ff3f2af31bbf3c2c74cb5d3454 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 16 Aug 2024 10:01:48 +0000 Subject: [PATCH 3/5] print on rank 0 --- colossalai/booster/plugin/gemini_plugin.py | 15 +++++++++------ .../booster/plugin/hybrid_parallel_plugin.py | 16 ++++++++++------ .../booster/plugin/low_level_zero_plugin.py | 15 +++++++++------ .../booster/plugin/moe_hybrid_parallel_plugin.py | 13 +++++++++---- 4 files changed, 37 insertions(+), 22 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 443c80831b14..3754cfe600bb 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -119,7 +119,7 @@ def save_sharded_model( """ assert isinstance(model, GeminiDDP), "Please boost the model before saving!" if os.path.isfile(checkpoint_path): - self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0]) return Path(checkpoint_path).mkdir(parents=True, exist_ok=True) @@ -147,7 +147,8 @@ def save_sharded_model( self.logger.info( f"The model is split into checkpoint shards. " f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + f"index located at {save_index_file}.", + ranks=[0], ) def load_sharded_model( @@ -169,7 +170,7 @@ def save_sharded_optimizer( assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): - self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -205,7 +206,8 @@ def save_sharded_optimizer( self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + f"index located at {save_index_file}.", + ranks=[0], ) def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str): @@ -215,7 +217,7 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi """ assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!" if not os.path.isfile(checkpoint_index_file): - self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file") + self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0]) assert isinstance(optimizer, GeminiOptimizer) @@ -374,7 +376,8 @@ def __init__( self.logger = get_dist_logger() if enable_async_reduce and not pin_memory: self.logger.warning( - f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set." 
+ f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.", + ranks=[0], ) pin_memory = True self.gemini_config = dict( diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 6c1515d38834..b4b40020fb2d 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1042,7 +1042,8 @@ def __init__( ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" if sp_size != 1: self.logger.warning( - f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.", + ranks=[0], ) self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -1129,7 +1130,8 @@ def __init__( if sequence_parallelism_mode == "ring_attn": if not parallel_output: self.logger.warning( - "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet." + "parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.", + ranks=[0], ) parallel_output = True @@ -1237,7 +1239,8 @@ def configure( if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0: self.logger.warning( - "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO." + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.", + ranks=[0], ) zero_config["partition_grad"] = False zero_stage = 0 @@ -1296,7 +1299,8 @@ def configure( if self.dp_size == 1: self.logger.warning( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you do not intend to use cpu_offload, please consider set zero_stage=0." + "If you do not intend to use cpu_offload, please consider set zero_stage=0.", + ranks=[0], ) assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." @@ -1339,7 +1343,7 @@ def execute_pipeline( assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled" if return_outputs: - self.logger.warning("return_outputs may lead to significant extra memory consumption.") + self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0]) # Create a context for gradient synchronization based on the optimizer type. # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync(). @@ -1454,7 +1458,7 @@ def enable_lora( assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model." assert self.pp_size == 1 and self.tp_size == 1 self.lora_enabled = True - self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr") + self.logger.warning("You have enabled LoRa training. 
Please check the hyperparameters such as lr", ranks=[0]) if bnb_quantization_config is not None: model = quantize_model(model, bnb_quantization_config) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 6c36bad3c214..a350b67d0123 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -138,7 +138,7 @@ def save_sharded_optimizer( """ assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!" if os.path.isfile(checkpoint): - self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return Path(checkpoint).mkdir(parents=True, exist_ok=True) @@ -178,7 +178,8 @@ def save_sharded_optimizer( self.logger.info( f"The optimizer is going to be split to checkpoint shards. " f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." + f"index located at {save_index_file}.", + ranks=[0], ) def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str): @@ -265,7 +266,7 @@ def save_sharded_model( def save_lora_as_pretrained(self, model, checkpoint, use_safetensors): if os.path.isfile(checkpoint): - self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file") + self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0]) return from peft import PeftModel @@ -396,7 +397,7 @@ def enable_lora( assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model." self.lora_enabled = True - self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr") + self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0]) if bnb_quantization_config is not None: model = quantize_model(model, bnb_quantization_config) @@ -446,7 +447,8 @@ def add_lora_params_to_optimizer(self, model, optimizer): group_id, check_state = self.get_param_group_id(optimizer, origin_param, param) if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND: self.logger.warning( - f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups." + f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.", + ranks=[0], ) elif ( check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED @@ -487,7 +489,8 @@ def configure( if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0: self.logger.warning( - "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO." + "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.", + ranks=[0], ) zero_optim_kwargs["partition_grad"] = False zero_stage = 0 diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 874028f09b86..36973b240896 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -221,7 +221,8 @@ def __init__( zero_stage = 1 self.logger.warning( f"overlap_communication and zero_stage are set to False and 1 because " - f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. 
" + f"ZeRO-2 or comm overlap cause program hang when some experts are not routed.", + ranks=[0], ) assert ( @@ -240,7 +241,9 @@ def __init__( ), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism" if sp_size != 1: self.logger.warning( - f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size." + f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}," + "will ignore the given sequence parallelism size.", + ranks=[0], ) self.sp_size = 1 self.dp_size = dist.get_world_size() // (tp_size * pp_size) @@ -402,7 +405,8 @@ def configure( ) if use_ddp: self.logger.warning( - f"Will have to check all params are used in pytorch DDP since not all experts are always activated" + f"Will have to check all params are used in pytorch DDP since not all experts are always activated", + ranks=[0], ) self.ddp_config["find_unused_parameters"] = True @@ -460,7 +464,8 @@ def configure( if self.dp_size <= 1: self.logger.warning( "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. " - "If you do not intend to use cpu_offload, please consider set zero_stage=0." + "If you do not intend to use cpu_offload, please consider set zero_stage=0.", + ranks=[0], ) assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO." optimizer = MoeHybridParallelZeroOptimizer( From fdf9473de66f7b60ae4e707202710a16f997f361 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 19 Aug 2024 08:03:29 +0000 Subject: [PATCH 4/5] fix typo --- colossalai/booster/plugin/low_level_zero_plugin.py | 1 - colossalai/booster/plugin/torch_ddp_plugin.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index a350b67d0123..cc3755e6f065 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -112,7 +112,6 @@ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, state_dict = optimizer.state_dict() if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors=False) - self.logger = get_dist_logger() def save_sharded_optimizer( self, diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 5116446a4295..8a807970ced2 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -9,6 +9,7 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.logging import get_dist_logger from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.utils import get_current_device @@ -21,6 +22,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() + self.logger = get_dist_logger() def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True): """ From f756f681a3362e9d2d7060108d2d2e77de391abe Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Mon, 19 Aug 2024 08:08:18 +0000 Subject: [PATCH 5/5] fixes --- colossalai/booster/booster.py | 18 +++++++++++++----- 
colossalai/zero/gemini/gemini_optimizer.py | 12 +++++++----- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 56d8a0935f10..8047d90f7a69 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -1,4 +1,3 @@ -import warnings from contextlib import contextmanager from typing import Any, Callable, Dict, Iterator, List, Optional, Union @@ -8,6 +7,8 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader +from colossalai.logging import get_dist_logger + SUPPORT_PEFT = False try: import peft @@ -81,12 +82,15 @@ def __init__( plugin, Plugin ), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}." self.plugin = plugin + self.logger = get_dist_logger() # set accelerator if self.plugin and self.plugin.control_device(): self.accelerator = None if device is not None: - warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.") + self.logger.warning( + "The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0] + ) else: device = device or "cuda" self.accelerator = Accelerator(device) @@ -94,7 +98,10 @@ def __init__( # set precision if self.plugin and self.plugin.control_precision(): if mixed_precision is not None: - warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.") + self.logger.warning( + "The plugin will control the precision," "so the mixed_precision argument will be ignored.", + ranks=[0], + ) self.mixed_precision = None elif mixed_precision is None: self.mixed_precision = None @@ -267,8 +274,9 @@ def enable_lora( ), "Please provide pretrained directory path if not passing in lora configuration." if quantize is True: if bnb_quantization_config is not None: - warnings.warn( - "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk." + self.logger.warning( + "User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.", + ranks=[0], ) else: bnb_quantization_config = BnbQuantizationConfig( diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 1d755c417b48..fdf2a497626f 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -1,7 +1,6 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch import copy import math -import warnings from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union import torch @@ -136,7 +135,7 @@ def __init__( self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0 self.verbose = verbose self.param_groups_backup = list() - + self.logger = get_dist_logger() # Mapping from integer id to real/fake param tensor, used for checkpointing. self.id_to_real_params: Dict[int, Parameter] = dict() self.id_to_fake_params: Dict[int, Parameter] = dict() @@ -148,9 +147,10 @@ def __init__( for name, param in module.named_parameters(): if is_ddp_ignored(param): if param.requires_grad: - warnings.warn( + self.logger.warning( f"Parameter `{name}` is ignored by DDP but requires gradient! " - "You should handle its optimizer update by yourself!" 
+ "You should handle its optimizer update by yourself!", + ranks=[0], ) else: ddp_param_list.append(param) @@ -842,7 +842,9 @@ def clip_grad_by_norm( *args, **kwargs, ) -> torch.Tensor: - warnings.warn(f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm") + self.logger.warning( + f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0] + ) class GeminiAdamOptimizer(GeminiOptimizer):