diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 3fbeebcc4110..d65bd437962e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -22,6 +22,7 @@ from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.policies.base_policy import Policy from colossalai.zero.low_level import LowLevelZeroOptimizer from .pp_plugin_base import PipelinePluginBase @@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): class HybridParallelModule(ModelWrapper): def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool, - ddp_config: dict) -> None: + ddp_config: dict, custom_policy: Policy) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.dp_group = dp_group shardformer = ShardFormer(shard_config) - module, self.shared_params = shardformer.optimize(module) + if custom_policy is not None: + assert isinstance(custom_policy, object) + module, self.shared_params = shardformer.optimize(module, policy=custom_policy) # setting process groups for shared parameters self.shared_param_process_groups = [] @@ -302,7 +305,8 @@ def __init__(self, zero_bucket_size_in_m: int = 12, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, - overlap_communication: bool = True) -> None: + overlap_communication: bool = True, + custom_policy: Policy = None) -> None: super().__init__() assert dist.get_world_size() % ( @@ -326,6 +330,7 @@ def __init__(self, self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) self.stage_manager = None self.schedule = None + self.custom_policy = custom_policy assert zero_stage in (0, 1, 2) if self.pp_size > 1: assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' @@ -405,7 +410,7 @@ def configure( if not isinstance(model, ModelWrapper): use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, - self.ddp_config) + self.ddp_config, self.custom_policy) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: if self.precision in ['fp16', 'bf16']: diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py new file mode 100644 index 000000000000..fab6c2f0cb7b --- /dev/null +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -0,0 +1,173 @@ +from typing import Optional + +import torch +import torch.distributed as dist + +from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.policies.base_policy import Policy + +PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 + + +class MoeHybridParallelPlugin(HybridParallelPlugin): + """ + Plugin for Moe Hybrid Parallel Training. 
+ Tensor parallel, pipeline parallel and data parallel(DDP/ZeRO) can be picked and combined in this plugin. + The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size). + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import HybridParallelPlugin + + >>> model, train_dataset, optimizer, criterion = ... + >>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2) + + >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader) + + Args: + tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1. + pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1. + precision (str, optional): Specifies the precision of parameters during training. + Auto-mixied precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'. + Defaults to 'fp16'. + zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be choosed from [0, 1, 2]. + When set to 0, ZeRO will not be used. Defaults to 0. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. + num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None. + microbatch_size (int, optional): Microbatch size when using pipeline parallelism. + Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline. + If ``num_microbatches`` is provided, this will be ignored. Defaults to None. + initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16. + min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1. + growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2. + backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5. + growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000. + hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2. + max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32. + max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0. + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True. + ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. 
Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False. + check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False. + static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False. + zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12. + cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False. + communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None. + overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True. + """ + + def __init__(self, + tp_size: int, + pp_size: int, + precision: str = 'fp16', + zero_stage: int = 0, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_jit_fused: bool = False, + enable_sequence_parallelism: bool = False, + enable_sequence_overlap: bool = False, + num_microbatches: Optional[int] = None, + microbatch_size: Optional[int] = None, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32, + max_norm: float = 0, + broadcast_buffers: bool = True, + ddp_bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + zero_bucket_size_in_m: int = 12, + cpu_offload: bool = False, + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + custom_policy: Policy = None) -> None: + + super().__init__() + assert dist.get_world_size() % ( + tp_size * pp_size + ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' + + if enable_sequence_parallelism: + assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism' + + self.tp_size = tp_size + self.pp_size = pp_size + self.dp_size = dist.get_world_size() // (tp_size * pp_size) + self.precision = precision + self.zero_stage = zero_stage + self.cpu_offload = cpu_offload + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_parallelism = enable_sequence_parallelism + # we change pg mesh to (pp, dp, tp) for better moe performance + self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) + self.stage_manager = None + self.schedule = None + self.custom_policy = custom_policy + assert zero_stage in (0, 1, 2) + if self.pp_size > 1: + assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism' + assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism' + self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS) + self.schedule = OneForwardOneBackwardSchedule(self.stage_manager, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size) + 
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS) + self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group, + pipeline_stage_manager=self.stage_manager, + enable_tensor_parallelism=self.tp_size > 1, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=enable_sequence_parallelism, + enable_sequence_overlap=enable_sequence_overlap) + self.amp_config = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + ) + + self.ddp_config = dict(broadcast_buffers=broadcast_buffers, + bucket_cap_mb=ddp_bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph) + + self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(self.zero_stage == 2)) + + self.max_norm = max_norm diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index 3dc27c6cb0f0..e61fb0bf9582 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -24,7 +24,9 @@ def __init__(self): self.router_z_loss = [] self.parallel = None self.seed = None - self.use_kernel_optim = True + self.mode = None + self.use_kernel_optim = False + self.use_ep_inside = None self.has_setup = False self._parallel_info_dict = dict() @@ -37,15 +39,53 @@ def parallel_info_dict(self): def is_initialized(self): return self.has_setup - def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: bool = None): + def setup(self, + seed: int, + use_kernel_optim: bool = True, + parallel: bool = None, + mode: str = "dynamic", + max_ep_size: int = 8, + fixed_dp_size: int = 0, + fixed_ep_size: int = 0, + fixed_pp_size: int = 0, + use_ep_inside: bool = True) -> None: + """ + Setup MoE distributed context. + + Args: + seed (int): Random seed. Defaults to 42. + use_kernel_optim (bool, optional): Use cuda kernel. Defaults to True. + parallel (bool, optional): Parallel mode, should be EP, TP or None. Defaults to None. + mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic". + In fixed mode, the ep size and dp size is fixed. + In dynamic mode, the ep size and dp size will be changed according to num experts. + max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8. + fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0. + fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0. + fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0. + use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. 
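+
+            Example (illustrative): mode="fixed" with fixed_dp_size=1, fixed_ep_size=2, fixed_pp_size=2
+                uses exactly those sizes, while mode="dynamic" derives the ep/dp sizes from the number
+                of experts and max_ep_size (see ``get_info``).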
+ """ assert not self.is_initialized, "MoE distributed context shouldn't be set up again" assert torch.cuda.is_available(), "MoE requires to enable CUDA first" self.world_size = dist.get_world_size() self.seed = seed + dist.get_rank() - self.max_ep_size = min(max_ep_size, dist.get_world_size()) - self.min_dp_size = self.world_size // self.max_ep_size self.parallel = parallel + self.use_ep_inside = use_ep_inside + + # init by mode + self.mode = mode + assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic" + if self.mode == "dynamic": + self.max_ep_size = min(max_ep_size, dist.get_world_size()) + self.min_dp_size = self.world_size // self.max_ep_size + else: + assert fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0, "dp_size, ep_size and pp_size should be greater than 0" + assert isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance( + fixed_pp_size, int), "dp_size, ep_size and pp_size should be int" + self.ep_size = fixed_ep_size + self.dp_size = fixed_dp_size + self.pp_size = fixed_pp_size # Enabling kernel optimization may raise error in some cases # Users can close kernel optimization manually @@ -67,30 +107,39 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara number of local experts, the MoeParallelInfo of the current ep_size """ - gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater - lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less - - assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ - " is not a multiple of ep size or vice versa." - - # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size, - # there are multiple experts in each GPU and each GPU has different experts - # So it's data parallel size is 1 - # Otherwise, there is only one expert in each GPU - # The data parallel size should be calculated - dp_size = 1 if gt_flag else self.max_ep_size // num_experts - ep_size = self.max_ep_size // dp_size + if self.mode == "dynamic": + gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater + lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less + + assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ + " is not a multiple of ep size or vice versa." + + # If the number of experts is greater than maximum expert parallel size. 
a.k.a ep_size, + # there are multiple experts in each GPU and each GPU has different experts + # So it's data parallel size is 1 + # Otherwise, there is only one expert in each GPU + # The data parallel size should be calculated + dp_size = 1 if gt_flag else self.max_ep_size // num_experts + ep_size = self.max_ep_size // dp_size + # Don't forget to multiply minimum data parallel size + dp_size *= self.min_dp_size + pp_size = 1 + else: + dp_size = self.dp_size + ep_size = self.ep_size + pp_size = self.pp_size # Calculate the number of experts for each GPU if use_tp: num_local_experts = num_experts else: - num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size + if self.mode == "dynamic": + num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size + else: + num_local_experts = num_experts // ep_size - # Don't forget to multiply minimum data parallel size - dp_size *= self.min_dp_size if not (ep_size in self.parallel_info_dict): - self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size) + self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside) return num_local_experts, self.parallel_info_dict[ep_size] diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py index 442b3c0f4958..9120a40b8533 100644 --- a/colossalai/tensor/moe_tensor/api.py +++ b/colossalai/tensor/moe_tensor/api.py @@ -28,20 +28,23 @@ def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None moe_info (dict): The moe info to be set. """ - tensor.__setattr__('moe_info', moe_info) + tensor.__setattr__("moe_info", moe_info) -def get_moe_info(ep_size: int, dp_size: int) -> MoeParallelInfo: +def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo: """ Get moe info for the given tensor. Args: - tensor (torch.Tensor): The tensor to be checked. + ep_size (int): The expert parallel size. + dp_size (int): The data parallel size. + pp_size (int): The pipeline parallel size. + ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Returns: dict: The moe info of the given tensor. """ - return MoeParallelInfo(ep_size, dp_size) + return MoeParallelInfo(ep_inside, ep_size, dp_size, pp_size) def get_ep_group(tensor: torch.Tensor) -> ProcessGroup: diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py index ca7f163b9c24..5097ac1044e7 100644 --- a/colossalai/tensor/moe_tensor/moe_info.py +++ b/colossalai/tensor/moe_tensor/moe_info.py @@ -2,15 +2,26 @@ class MoeParallelInfo: - """Moe parallelism information, storing parallel sizes and groups. - """ + """Moe parallelism information, storing parallel sizes and groups.""" + + def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1): + """ + init MoeParallelInfo with ep_size, dp_size and pp_size + + Args: + ep_size (int): expert parallel size + dp_size (int): data parallel (zero) size + pp_size (int, optional): pipeline parallel size. Defaults to 1. + ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. 
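+
+            For example (illustrative): ep_inside=True builds the process group mesh as
+                (pp_size, dp_size, ep_size), while ep_inside=False builds it as (pp_size, ep_size, dp_size).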
+ """ + self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size + if ep_inside: + self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2 + self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size) + else: + self.pp_axis, self.ep_axis, self.dp_axis = 0, 1, 2 + self.pg = ProcessGroupMesh(self.pp_size, self.ep_size, self.dp_size) - def __init__(self, ep_size: int, dp_size: int): - self.dp_axis = 0 - self.dp_size = dp_size - self.ep_axis = 1 - self.ep_size = ep_size - self.pg = ProcessGroupMesh(self.dp_size, self.ep_size) self.ep_group = self.pg.get_group_along_axis(self.ep_axis) self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group) self.dp_group = self.pg.get_group_along_axis(self.dp_axis) diff --git a/examples/language/openmoe/model/__init__.py b/examples/language/openmoe/model/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 6ccbf64a60e4..90d3e0022ce4 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -145,87 +145,6 @@ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): return out_q, out_k -class LlamaRotaryEmbedding(torch.nn.Module): - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.inv_freq = inv_freq - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., :x.shape[-1] // 2] @@ -233,17 +152,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - def SwiGLU(x): """Gated linear unit activation function. 
Args: @@ -256,7 +164,7 @@ def SwiGLU(x): return x1 * (x2 * torch.sigmoid(x2)) -class LlamaMLP(nn.Module): +class OpenMoeMLP(nn.Module): def __init__(self, config): super().__init__() @@ -267,6 +175,7 @@ def __init__(self, config): self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = SwiGLU + self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False def forward(self, x): if self.pretraining_tp > 1: @@ -282,7 +191,7 @@ def forward(self, x): down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] down_proj = sum(down_proj) else: - if HAS_TRITON: + if HAS_TRITON and self.use_kernel: down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x))) else: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) @@ -302,7 +211,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class LlamaAttention(nn.Module): +class OpenMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig): @@ -321,22 +230,6 @@ def __init__(self, config: LlamaConfig): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1e4) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -446,13 +339,13 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaDecoderLayer(nn.Module): +class OpenMoeDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, moe: bool): super().__init__() self.hidden_size = config.hidden_size self.moe = moe - self.self_attn = LlamaAttention(config=config) + self.self_attn = OpenMoeAttention(config=config) self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: @@ -470,9 +363,9 @@ def __init__(self, config: LlamaConfig, moe: bool): activation=config.hidden_act, gated=config.gated) self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) - self.extra_mlp = LlamaMLP(config) + self.extra_mlp = OpenMoeMLP(config) else: - self.mlp = LlamaMLP(config) + self.mlp = OpenMoeMLP(config) def forward( self, @@ -556,7 +449,7 @@ def forward( "The bare LLaMA Model outputting raw 
hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) -class LlamaPreTrainedModel(PreTrainedModel): +class OpenMoePreTrainedModel(PreTrainedModel): config_class = LlamaConfig base_model_prefix = "model" supports_gradient_checkpointing = True @@ -575,7 +468,7 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): + if isinstance(module, OpenMoeModel): module.gradient_checkpointing = value @@ -647,7 +540,7 @@ def _set_gradient_checkpointing(self, module, value=False): "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) -class LlamaModel(LlamaPreTrainedModel): +class OpenMoeModel(OpenMoePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] @@ -662,7 +555,7 @@ def __init__(self, config: LlamaConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) + OpenMoeDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) for i in range(config.num_hidden_layers) ]) self.norm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -827,12 +720,12 @@ def custom_forward(*inputs): ) -class OpenMoeForCausalLM(LlamaPreTrainedModel): +class OpenMoeForCausalLM(OpenMoePreTrainedModel): # _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) - self.model = LlamaModel(config) + self.model = OpenMoeModel(config) self.pretraining_tp = config.pretraining_tp self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -1022,17 +915,15 @@ def _reorder_cache(past_key_values, beam_idx): past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) return reordered_past - def _calculate_router_loss(self): - aux_loss, z_loss = MOE_MANAGER.get_loss() + def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None): + if aux_loss is None or z_loss is None: + aux_loss, z_loss = MOE_MANAGER.get_loss() assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss) z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss) return aux_loss, z_loss - def _calculate_loss(self, - logits: torch.Tensor, - targets: torch.Tensor - ) -> torch.Tensor: + def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """Compute cross entropy and entropy for log probs and targets. 
Args: diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py new file mode 100644 index 000000000000..cc82683cd319 --- /dev/null +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -0,0 +1,545 @@ +import warnings +from functools import partial +from typing import Callable, Dict, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Module +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import logging + +from colossalai.moe.manager import MOE_MANAGER +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel + +__all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"] + + +class OpenMoePolicy(Policy): + + def config_sanity_check(self): + pass + + def preprocess(self): + if self.shard_config.enable_tensor_parallelism: + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + policy = {} + + if self.shard_config.enable_sequence_parallelism: + self.shard_config.enable_sequence_parallelism = False + raise NotImplementedError( + "openmoe dosen't support sequence parallelism now, will ignore the sequence parallelism flag.") + + if self.shard_config.enable_tensor_parallelism: + raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.") + + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement( + description=[ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="pre_extra_mlp_layernorm", + target_module=FusedRMSNorm, + ignore_if_not_exist=True, + ), + ], + policy=policy, + target_key=OpenMoeDecoderLayer, + ) + + self.append_or_create_submodule_replacement( + description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=OpenMoeModel, + ) + + if self.shard_config.enable_flash_attention: + raise NotImplementedError("Flash attention has already been replaced in openmoe.") + + return policy + + def postprocess(self): + return self.model + + def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: + """If under pipeline parallel setting, replacing the original forward method of huggingface + to customized forward method, and add this changing to policy.""" + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == "OpenMoeModel": + module = self.model + else: + module = self.model.model + + layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, 
stage_manager.stage) + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + + return + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None + + if self.model.__class__.__name__ == "LlamaModel": + module = self.model + else: + module = self.model.model + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embed_tokens) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.layers[start_idx:end_idx]) + if stage_manager.is_last_stage(): + held_layers.append(module.norm) + + return held_layers + + +class OpenMoeModelPolicy(OpenMoePolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=OpenMoeModel, + new_forward=OpenMoePipelineForwards.openmoe_model_forward, + policy=policy, + ) + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + held_layers = super().get_held_layers() + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + """No shared params in llama model""" + return [] + + +class OpenMoeForCausalLMPolicy(OpenMoePolicy): + + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + OpenMoeForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True), + ) + ]) + } + policy.update(new_item) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=OpenMoeForCausalLM, + new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward, + policy=policy, + ) + + return policy + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_last_stage(): + held_layers.append(self.model.lm_head) + return held_layers + + def get_shared_params(self) -> List[Dict[int, Tensor]]: + llama_model = self.model.model + if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: + if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1): + # tie weights + return [{ + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + }] + return [] + + +class OpenMoePipelineForwards: + """ + This class serves as a micro library for forward function substitution of Llama models + under pipeline setting. 
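+
+    Besides hidden states, intermediate stages also pass the accumulated router aux/z losses
+    ("router_aux_loss"/"router_z_loss") to the next stage so that the last stage can compute the MoE loss.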
+ """ + + @staticmethod + def openmoe_model_forward( + self: OpenMoeModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + past_router_aux_loss: Optional[torch.FloatTensor] = None, + past_router_z_loss: Optional[torch.FloatTensor] = None, + ): + # reset moe loss for different data + MOE_MANAGER.reset_loss() + + logger = logging.get_logger(__name__) + + output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + # retrieve input_ids and inputs_embeds + if stage_manager.is_first_stage(): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + device = hidden_states.device + + seq_length_with_past = seq_length + past_key_values_length = 0 + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + if use_cache: + logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions, for the first stage, hidden_states is the input embeddings, + # for the other stages, hidden_states is the output of the previous stage + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=hidden_states.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + hidden_states, + past_key_values_length, + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + start_idx, end_idx = stage_index[0], stage_index[1] + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = (past_key_values[idx] if past_key_values is not None else None) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if stage_manager.is_last_stage(): + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + + # concat past losses with current ones + router_aux_loss, router_z_loss = MOE_MANAGER.get_loss() + if past_router_aux_loss is not None and past_router_z_loss is not None: + router_aux_loss = past_router_aux_loss + router_aux_loss + router_z_loss = past_router_z_loss + router_z_loss + + if stage_manager.is_last_stage(): + return tuple([ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + router_aux_loss, + router_z_loss, 
+ ]) + # always return dict for imediate stage + return { + "hidden_states": hidden_states, + "router_aux_loss": router_aux_loss, + "router_z_loss": router_z_loss, + } + + @staticmethod + def llama_for_causal_lm_forward( + self: OpenMoeForCausalLM, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + chunk_head: Optional[bool] = None, + past_router_aux_loss: Optional[torch.FloatTensor] = None, + past_router_z_loss: Optional[torch.FloatTensor] = None, + ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + logger = logging.get_logger(__name__) + output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + + # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
+ if output_attentions: + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") + output_attentions = False + if output_hidden_states: + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") + output_hidden_states = False + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = OpenMoePipelineForwards.openmoe_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + past_router_aux_loss=past_router_aux_loss, + past_router_z_loss=past_router_z_loss, + ) + + if stage_manager.is_last_stage(): + ( + hidden_states, + past_key_values, + all_hidden_states, + attentions, + router_aux_loss, + router_z_loss, + ) = outputs + + if self.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + + loss = None + # if no training, just do forward + if labels is None: + logits = self.lm_head(hidden_states) + logits = logits.float() + # the vocab size for openmoe is 30w+ + # which causes great activation memory in training, up to 20G for one sequence + # so we use chunk and checkpoint to reduce memory + else: + if chunk_head == True: + + def create_custom_forward(module): + + def custom_forward(*inputs): + logits = module(inputs[0]) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous().float() + shift_labels = inputs[1][..., 1:].contiguous() + # Flatten the tokens + loss = self._calculate_loss(shift_logits, shift_labels) + return loss + + return custom_forward + + aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss) + loss = aux_loss + z_loss + for batch_idx in range(hidden_states.shape[0]): + loss = loss + torch.utils.checkpoint.checkpoint( + create_custom_forward(self.lm_head), + hidden_states[batch_idx:batch_idx + 1, :], + labels[batch_idx:batch_idx + 1, :], + ) + logits = None + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss) + loss = aux_loss + z_loss + loss = loss + self._calculate_loss(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=attentions, + ) + else: + hidden_states = outputs["hidden_states"] + router_aux_loss = outputs["router_aux_loss"] + router_z_loss = outputs["router_z_loss"] + return { + "hidden_states": hidden_states, + "past_router_aux_loss": router_aux_loss, + "past_router_z_loss": router_z_loss, + } diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 75eee902c747..e69de29bb2d1 100644 --- 
a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -1,5 +0,0 @@ -set -xe -pip install -r requirements.txt - -python infer.py --model "test" -torchrun --standalone --nproc_per_node 2 train.py --model_name "test" --batch_size 1 --num_epoch 20 diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 132f17a9ba0f..2099bbde91f5 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -5,6 +5,7 @@ import transformers from huggingface_hub import snapshot_download from model.modeling_openmoe import OpenMoeForCausalLM +from model.openmoe_policy import OpenMoeForCausalLMPolicy from torch.utils.data import Dataset from tqdm import tqdm from transformers import Adafactor, T5Tokenizer @@ -14,6 +15,7 @@ from colossalai import get_default_parser from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.moe import MoeCheckpintIO @@ -52,31 +54,67 @@ def __len__(self): def __getitem__(self, idx): return { - 'input_ids': self.input_ids[idx], - 'attention_mask': self.attention_mask[idx], - 'labels': self.input_ids[idx] + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], } def parse_args(): + # basic settings parser = get_default_parser() - parser.add_argument("--model_name", - type=str, - default="base", - help="Path to pretrained model or model identifier from huggingface.co/models.") - parser.add_argument("--output_path", - type=str, - default="./output_model.bin", - help="The path of your saved model after finetuning.") + parser.add_argument( + "--model_name", + type=str, + default="base", + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./output_model.bin", + help="The path of your saved model after finetuning.", + ) parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.") - parser.add_argument("--batch_size", - type=int, - default=4, - help="Batch size (per dp group) for the training dataloader.") + parser.add_argument( + "--batch_size", + type=int, + default=4, + help="Batch size (per dp group) for the training dataloader.", + ) parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + help="parallel plugin", + choices=["zero1", "zero2", "hybrid"], + ) + # hybrid plugin + parser.add_argument("--pp_size", type=int, default=2, help="pp size") + parser.add_argument("--dp_size", type=int, default=1, help="dp size") + parser.add_argument("--ep_size", type=int, default=2, help="ep size") + parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. 
Need to install flash attention, apex, triton to enable all kernel optimizations.", + ) # loss - parser.add_argument("--router_aux_loss_factor", type=float, default=0.01, help="router_aux_loss_factor.") - parser.add_argument("--router_z_loss_factor", type=float, default=0.0001, help="router_z_loss_factor.") + parser.add_argument( + "--router_aux_loss_factor", + type=float, + default=0.01, + help="router_aux_loss_factor.", + ) + parser.add_argument( + "--router_z_loss_factor", + type=float, + default=0.0001, + help="router_z_loss_factor.", + ) parser.add_argument("--label_smoothing", type=float, default=0.0, help="label_smoothing.") parser.add_argument("--z_loss_factor", type=float, default=0.0001, help="z_loss_factor.") # optim @@ -95,7 +133,24 @@ def main(): coordinator = DistCoordinator() # Set up moe - MOE_MANAGER.setup(seed=42, parallel="EP") + if args.plugin in ["zero1", "zero2"]: + MOE_MANAGER.setup( + seed=42, + parallel="EP", + use_kernel_optim=False if args.model_name == "test" else args.use_kernel, + ) + elif args.plugin == "hybrid": + assert (args.dp_size * args.ep_size * + args.pp_size == coordinator.world_size), "dp_size * ep_size * pp_size must equal to world_size" + MOE_MANAGER.setup( + seed=42, + parallel="EP", + mode="fixed", + fixed_dp_size=args.dp_size, + fixed_ep_size=args.ep_size, + fixed_pp_size=args.pp_size, + use_kernel_optim=False if args.model_name == "test" else args.use_kernel, + ) # Manage loggers disable_existing_loggers() @@ -129,12 +184,27 @@ def main(): # Set plugin booster_kwargs = {} - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + if args.plugin == "zero1": + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + elif args.plugin == "hybrid": + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=args.pp_size, + zero_stage=args.zero_stage, + microbatch_size=args.microbatch_size, + custom_policy=OpenMoeForCausalLMPolicy(), + enable_fused_normalization=args.use_kernel, + enable_jit_fused=args.use_kernel, + ) + else: + raise ValueError(f"Invalid plugin {args.plugin}") logger.info(f"Set plugin as {plugin}", ranks=[0]) # Prepare tokenizer and dataloader tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 1) + dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 50) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # Set optimizer @@ -143,27 +213,47 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) + use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() logger.info(f"Finish init booster", ranks=[0]) # Start finetuning logger.info(f"Start finetuning", ranks=[0]) for epoch in range(args.num_epoch): model.train() - with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: - for batch in pbar: - # Forward - optimizer.zero_grad() - batch = move_to_cuda(batch, torch.cuda.current_device()) - - outputs = model(use_cache=False, chunk_head=True, **batch) - loss = outputs['loss'] + train_dataloader_iter = iter(dataloader) + total_len = len(train_dataloader_iter) + with tqdm( + range(total_len), + 
desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", + disable=not coordinator.is_master(), + ) as pbar: + # Forward pass + for _ in pbar: + if use_pipeline: + outputs = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) + else: + data = next(train_dataloader_iter) + data = move_to_cuda(data, torch.cuda.current_device()) + outputs = model(**data) + loss = outputs["loss"] + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) - # Backward - booster.backward(loss, optimizer) optimizer.step() - - # Print batch loss - pbar.set_postfix({'loss': loss.item()}) + optimizer.zero_grad() # Finish training and evaluate logger.info(f"Finish finetuning", ranks=[0]) diff --git a/examples/language/openmoe/train.sh b/examples/language/openmoe/train.sh index 9a55779ca5ef..6712aa10a88b 100644 --- a/examples/language/openmoe/train.sh +++ b/examples/language/openmoe/train.sh @@ -1,3 +1,9 @@ -torchrun --standalone --nproc_per_node 2 train.py \ +torchrun --standalone --nproc_per_node 4 train.py \ --model_name "base" \ + --plugin "hybrid" \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 2 \ + --use_kernel \ + --zero_stage 1 \ --batch_size 4