diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 784204528d65..5171780da347 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,23 +1,71 @@ import random -from typing import Optional +from typing import Callable, Optional, OrderedDict, Tuple import numpy as np import torch import torch.distributed as dist +from torch.distributed import ProcessGroup +from torch.nn import Module +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelPlugin +from colossalai.booster.plugin.hybrid_parallel_plugin import ( + HybridParallelAMPOptimizer, + HybridParallelModule, + HybridParallelNaiveOptimizer, + HybridParallelPlugin, + get_param_info, + init_pipeline_optimizer, +) from colossalai.cluster import ProcessGroupMesh +from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.moe import MoeCheckpintIO from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig from colossalai.shardformer.policies.base_policy import Policy +from colossalai.zero.low_level import LowLevelZeroOptimizer PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 +class HybridParallelZeroOptimizer(LowLevelZeroOptimizer): + + def __init__( + self, + optimizer: Optimizer, + model: Module, + use_pipeline: bool, + param_info: OrderedDict, + initial_scale: int = 2**16, # grad scaler config + min_scale: int = 1, + growth_factor: float = 2., + backoff_factor: float = .5, + growth_interval: int = 2000, + hysteresis: int = 2, + max_scale: int = 2**24, + clip_grad_norm: float = 0.0, # grad clipping + verbose: bool = False, + reduce_bucket_size: int = 1024 * 1024, # communication + communication_dtype: Optional[torch.dtype] = None, + overlap_communication: bool = True, + partition_grad: bool = False, # stage 2 flag + cpu_offload: bool = False, # cpu offload + dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm + tp_process_group: Optional[ProcessGroup] = None, # if using tp + forced_dtype: Optional[torch.dtype] = None, + extra_dp_process_group: Optional[ProcessGroup] = None): + self.param_info = param_info + if use_pipeline: + init_pipeline_optimizer(optimizer, model) + super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, + hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype, + overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group, + forced_dtype, extra_dp_process_group) + + class MoeHybridParallelPlugin(HybridParallelPlugin): """ Plugin for Moe Hybrid Parallel Training. 
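A minimal sketch of how the new plugin, the extra dp group, and HybridParallelZeroOptimizer are driven, mirroring the ep_zero branch of examples/language/openmoe/benchmark/benchmark_cai.py later in this patch (model, optimizer, and dataloader construction is elided; the custom_policy and kernel flags used in the benchmark are omitted for brevity):

    import torch.distributed as dist

    from colossalai.booster import Booster
    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
    from colossalai.moe.manager import MOE_MANAGER

    # experts are sharded over dp_size // extra_dp_size ranks; each expert shard is
    # replicated extra_dp_size times, and that replica group is the extra dp group
    extra_dp_size = 2
    dp_size = dist.get_world_size()
    MOE_MANAGER.setup(seed=42, parallel="EP", max_ep_size=dp_size // extra_dp_size, use_ep_inside=False)
    plugin = MoeHybridParallelPlugin(tp_size=1, pp_size=1, zero_stage=1,
                                     extra_dp_size=extra_dp_size, use_ep_inside=False)
    booster = Booster(plugin=plugin)
    # model, optimizer, criterion, and dataloader are assumed to be built beforehand
    model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)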
@@ -78,6 +126,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): def __init__(self, tp_size: int, pp_size: int, + extra_dp_size: int = 1, precision: str = 'fp16', zero_stage: int = 0, enable_all_optimization: bool = False, @@ -106,6 +155,7 @@ def __init__(self, cpu_offload: bool = False, communication_dtype: Optional[torch.dtype] = None, overlap_communication: bool = True, + use_ep_inside: bool = True, custom_policy: Policy = None) -> None: super().__init__(tp_size=tp_size, @@ -132,6 +182,23 @@ def __init__(self, self.enable_sequence_parallelism = enable_sequence_parallelism # we change pg mesh to (pp, dp, tp) for better moe performance self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size) + + # sync moe params in the outer dp group, and sync other params in the global dp group + if extra_dp_size > 1: + ep_size = self.dp_size // extra_dp_size + if use_ep_inside: + self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size) + self.extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1) + if dist.get_rank() == 0: + print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}") + else: + self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size) + self.extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2) + if dist.get_rank() == 0: + print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}") + else: + self.extra_dp_group = None + self.stage_manager = None self.schedule = None self.custom_policy = custom_policy @@ -235,3 +302,52 @@ def seed_worker(worker_id): def get_checkpoint_io(self) -> MoeCheckpintIO: self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) return self.checkpoint_io + + def configure( + self, + model: Module, + optimizer: Optional[Optimizer] = None, + criterion: Optional[Callable] = None, + dataloader: Optional[DataLoader] = None, + lr_scheduler: Optional[LRScheduler] = None, + ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) + if not isinstance(model, ModelWrapper): + use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0 + model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp, + self.ddp_config, self.custom_policy) + if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): + if self.zero_stage == 0: + if self.precision in ['fp16', 'bf16']: + optimizer = HybridParallelAMPOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + precision=self.precision, + max_norm=self.max_norm, + **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map, + optimizer.master_to_working_map) + else: + optimizer = HybridParallelNaiveOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info) + else: + assert self.dp_size > 1, "ZeRO requires a data parallel size greater than 1." + assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
+ optimizer = HybridParallelZeroOptimizer(optimizer, + model, + use_pipeline=self.enable_pipeline_parallelism, + param_info=param_info, + dp_process_group=self.dp_group, + tp_process_group=self.tp_group, + extra_dp_process_group=self.extra_dp_group, + verbose=True, + clip_grad_norm=self.max_norm, + **self.zero_config, + **self.amp_config) + self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param, + optimizer._param_store.master_to_working_param) + + return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index e05ea59b3d28..81a7b21544e4 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -39,27 +39,28 @@ def __init__( activation: Optional[Callable] = None, drop_rate: float = 0, gated: bool = False, + use_kernel: bool = False, ): super().__init__() assert expert_parallel in ["EP", "TP", None] self.expert_parallel = expert_parallel self.num_total_experts = num_experts self.gated = gated + self.use_kernel = use_kernel + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size # get expert parallel info if expert_parallel is not None: self.num_local_experts, self.moe_info = MOE_MANAGER.get_info( num_experts, use_tp=True if expert_parallel == "TP" else False) # get settings for different parallel + self.ep_size = get_ep_size(self) if expert_parallel == "TP": - assert ( - intermediate_size % - MOE_MANAGER.max_ep_size == 0), "intermediate_size should be divide by maximum expert parallel size" - intermediate_size = intermediate_size // MOE_MANAGER.max_ep_size + intermediate_size = intermediate_size // self.ep_size num_experts = self.num_total_experts else: num_experts = self.num_local_experts - self.ep_size = get_ep_size(self) else: self.num_local_experts = self.num_total_experts self.ep_size = 1 @@ -71,19 +72,6 @@ def __init__( self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size)) self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size)) - # expert param should be different - if expert_parallel is not None: - seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True) - else: - seed_ctx = nullcontext() - with seed_ctx: - if gated: - torch.nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size)) - torch.nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size)) - else: - torch.nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) - torch.nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) - self.act_name = activation self.act = get_activation(activation) self.drop = nn.Dropout(p=drop_rate) @@ -92,6 +80,24 @@ def __init__( for param in self.parameters(): set_moe_tensor_info(param, self.moe_info) + # init param + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self): + # expert param should be different + if self.expert_parallel is not None: + seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True) + else: + seed_ctx = Randomizer(42).fork_rng(enable_cpu=True) + with seed_ctx: + if self.gated: + torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size)) + torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size)) + else: + torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size)) + torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size)) + def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor: """ Args: @@ -110,7 
+116,7 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) - x = x.reshape(e, -1, h) if self.gated: - if HAS_TRITON and self.act_name == "swiglu": + if self.use_kernel and HAS_TRITON and self.act_name == "swiglu": x = LlamaActCombine.apply( torch.bmm(x, self.wi_gate[param_slice]), torch.bmm(x, self.wi_up[param_slice]), @@ -142,7 +148,9 @@ def __init__( activation=None, drop_rate: float = 0, gated: bool = False, + use_kernel: bool = False, ): + # TODO: this class can be removed super().__init__( num_experts, hidden_size, @@ -151,6 +159,7 @@ def __init__( activation, drop_rate, gated, + use_kernel, ) @@ -168,7 +177,9 @@ def __init__( activation: str = None, drop_rate: float = 0, gated: bool = False, + use_kernel: bool = False, ): + # TODO: this class can be removed super().__init__( num_experts, hidden_size, @@ -177,6 +188,7 @@ def __init__( activation, drop_rate, gated, + use_kernel, ) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index c2cf627aceae..036bd32ae7c0 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -51,7 +51,6 @@ def __init__( min_capacity: int = 4, noisy_policy: Optional[str] = None, drop_tks: bool = True, - expert_parallel: str = "EP", hidden_size: int = 2048, intermediate_size: int = 2048, activation: str = None, @@ -59,14 +58,16 @@ def __init__( ): super().__init__() self.hidden_size = hidden_size + self.intermediate_size = intermediate_size self.num_experts = num_experts self.use_kernel = MOE_MANAGER.use_kernel_optim - self.expert_parallel = expert_parallel - assert expert_parallel in [ + self.expert_parallel = MOE_MANAGER.get_parallel() + self.gated = gated + assert self.expert_parallel in [ "EP", "TP", None, - ], f"Unsupported expert parallel type {expert_parallel}" + ], f"Unsupported expert parallel type {self.expert_parallel}" # moe router noisy_func = get_noise_generator(noisy_policy, num_experts) @@ -80,23 +81,29 @@ def __init__( ) # moe experts - expert_cls = get_expert_class(expert_parallel) - self.experts: BaseMLPExperts = expert_cls( - num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - activation=activation, - gated=gated, - ) - if expert_parallel is not None: + expert_cls = get_expert_class(self.expert_parallel) + self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation=activation, + gated=gated, + use_kernel=self.use_kernel) + if self.expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) else: self.ep_group = None self.num_local_experts = self.experts.num_local_experts + # gate self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) - nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) + + # init param + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self): + torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -171,7 +178,7 @@ def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: expert_output = AllToAll.apply(expert_output, self.ep_group) return expert_output - def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: + def _tp_process(self, dispatch_data: torch.Tensor, use_overlap: bool = False) -> torch.Tensor: """ TP with overlap.
@@ -191,6 +198,13 @@ def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: (num_experts, capacity, hidden_size) """ + if not use_overlap: + expert_in, _ = AllGather.apply(dispatch_data, self.ep_group) + expert_out = self.experts(expert_in) + expert_out, _ = ReduceScatter.apply(expert_out, self.ep_group) + return expert_out + + # TODO: the overlapped path below has an accuracy problem chunk_num = 4 chunk_size = dispatch_data.shape[0] // chunk_num out = torch.empty_like(dispatch_data) diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index e61fb0bf9582..1e949bb9a6dd 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -41,8 +41,8 @@ def is_initialized(self): def setup(self, seed: int, - use_kernel_optim: bool = True, - parallel: bool = None, + use_kernel_optim: bool = False, + parallel: str = None, mode: str = "dynamic", max_ep_size: int = 8, fixed_dp_size: int = 0, @@ -140,6 +140,11 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara if not (ep_size in self.parallel_info_dict): self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside) + if dist.get_rank() == 0: + if self.use_ep_inside: + print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}") + else: + print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}") return num_local_experts, self.parallel_info_dict[ep_size] diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 78d85bcb5432..f08ebea58589 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -7,6 +7,7 @@ import torch import torch.distributed as dist import torch.nn as nn +from torch._utils import _flatten_dense_tensors from torch.distributed import ProcessGroup from torch.optim import Optimizer @@ -28,6 +29,7 @@ has_inf_or_nan, release_param_grad, sync_tensor, + unflatten, ) from .bookkeeping import BucketStore, GradientStore, ParameterStore @@ -80,7 +82,8 @@ def __init__( cpu_offload: bool = False, # cpu offload dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm tp_process_group: Optional[ProcessGroup] = None, # if using tp - forced_dtype: Optional[torch.dtype] = None): + forced_dtype: Optional[torch.dtype] = None, + extra_dp_process_group: Optional[ProcessGroup] = None): super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype @@ -100,6 +103,16 @@ def __init__( self._local_rank = dist.get_rank(group=self.dp_pg) self._world_size = dist.get_world_size(group=self.dp_pg) + # extra dp + # This group is used to sync moe params; dp_world_size = moe_duplicates * extra_dp_size. + # Non-moe params are synced by the global dp pg, while moe params are synced by the extra dp pg. + # Moe param grads are split by the global dp pg like non-moe params, and are merged back in step(). + # Moe working and master params are split by the extra dp pg.
+ self.extra_dp_pg = extra_dp_process_group + if self.extra_dp_pg is not None: + self.extra_dp_pg_size = dist.get_world_size(group=self.extra_dp_pg) + self.extra_dp_pg_rank = dist.get_rank(group=self.extra_dp_pg) + self.tp_pg = tp_process_group # working and master params for mixed precision training @@ -143,10 +156,11 @@ def __init__( group_params = list() for param in param_group['params']: if param.requires_grad: - # skip moe param - if is_moe_tensor(param): - moe_params.append(param) - continue + if self.extra_dp_pg is None: + # skip moe param + if is_moe_tensor(param): + moe_params.append(param) + continue group_params.append(param) # add the working params to working_param_groups for bookkeeping @@ -226,9 +240,12 @@ def _create_master_param_current_rank(self, param_list): padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) else: padding_param = param.data.view(-1) - splited_params = padding_param.split(padding_param.numel() // self._world_size) - - splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device) + if self.extra_dp_pg is not None and is_moe_tensor(param): + splited_params = padding_param.split(padding_param.numel() // self.extra_dp_pg_size) + splited_param_current_rank = splited_params[self.extra_dp_pg_rank].detach().float().to(device) + else: + splited_params = padding_param.split(padding_param.numel() // self._world_size) + splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device) params_current_rank.append(splited_param_current_rank) self._param_store.link_master_and_working_param(splited_param_current_rank, param) @@ -261,8 +278,9 @@ def _run_reduction(self): if self._bucket_store.num_elements_in_bucket() > 0: self._bucket_store.build_grad_in_bucket() - flat_grads = self._bucket_store.get_flatten_grad() - flat_grads /= self._world_size + if self.extra_dp_pg is None: + flat_grads = self._bucket_store.get_flatten_grad() + flat_grads /= self._world_size # ready to add other tensors to bucket self._bucket_store.reset_num_elements_in_bucket() @@ -270,7 +288,8 @@ def _run_reduction(self): if self._overlap_communication: stream = self._comm_stream # in case of the memory being reused in the default stream - flat_grads.record_stream(stream) + if self.extra_dp_pg is None: + flat_grads.record_stream(stream) # waiting for ops in the default stream finishing stream.wait_stream(torch.cuda.current_stream()) else: @@ -279,27 +298,73 @@ def _run_reduction(self): with torch.cuda.stream(stream): group_id = self._bucket_store.current_group_id - grad_dtype = flat_grads.dtype - if self._communication_dtype is not None: - flat_grads = flat_grads.to(self._communication_dtype) + if self.extra_dp_pg is None: + grad_dtype = flat_grads.dtype + if self._communication_dtype is not None: + flat_grads = flat_grads.to(self._communication_dtype) if not self._partition_grads: - dist.all_reduce(flat_grads, group=self.dp_pg) - if flat_grads.dtype != grad_dtype: - flat_grads = flat_grads.to(grad_dtype) - - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) - grad_in_bucket = self._bucket_store.get_grad() - - for rank, grad_list in grad_in_bucket.items(): - sync_tensor(flat_grads_per_rank[rank], grad_list) - for grad in grad_list: - param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, - param_id)) < self._world_size: - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - 
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + if self.extra_dp_pg is None: + dist.all_reduce(flat_grads, group=self.dp_pg) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) + + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + grad_in_bucket = self._bucket_store.get_grad() + + for rank, grad_list in grad_in_bucket.items(): + sync_tensor(flat_grads_per_rank[rank], grad_list) + for grad in grad_list: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id( + group_id, param_id)) < self._world_size: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + + # sync moe and non-moe grads separately when using the extra dp group + else: + # record which params are moe params + moe_list = [] + for param in self._bucket_store._param_list: + moe_list.append(is_moe_tensor(param)) + + # divide the grads into moe and non-moe groups + moe_grad_list = [] + non_moe_grad_list = [] + for grad_list in self._bucket_store._grad_in_bucket.values(): + non_moe_cur_grad = [] + moe_cur_grad = [] + for i in range(len(grad_list)): + if moe_list[i]: + moe_cur_grad.append(grad_list[i]) + else: + non_moe_cur_grad.append(grad_list[i]) + if len(moe_cur_grad) > 0: + moe_grad_list.append(moe_cur_grad) + if len(non_moe_cur_grad) > 0: + non_moe_grad_list.append(non_moe_cur_grad) + + # sync non-moe grads in the global dp group + if len(non_moe_grad_list) > 0: + flat_grads = [] + for grad_list in non_moe_grad_list: + flat_grads.append(_flatten_dense_tensors(grad_list)) + flat_grads = _flatten_dense_tensors(flat_grads) + flat_grads /= self._world_size + dist.all_reduce(flat_grads, group=self.dp_pg) + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + self._sync_unpartitioned_grad(non_moe_grad_list, flat_grads_per_rank, group_id) + + # sync moe grads only in the extra dp group + if len(moe_grad_list) > 0: + flat_grads = [] + for grad_list in moe_grad_list: + flat_grads.append(_flatten_dense_tensors(grad_list)) + flat_grads = _flatten_dense_tensors(flat_grads) + dist.all_reduce(flat_grads, group=self.extra_dp_pg) + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + self._sync_unpartitioned_grad(moe_grad_list, flat_grads_per_rank, group_id) else: flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) @@ -320,6 +385,16 @@ def _run_reduction(self): self._bucket_store.reset() + def _sync_unpartitioned_grad(self, origin_grad_list, flat_grad_list, group_id): + for rank, grad_list in enumerate(origin_grad_list): + sync_tensor(flat_grad_list[rank], grad_list) + for grad in grad_list: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < self._world_size: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + def _add_to_bucket(self, param, group_id): param_size = param.numel() @@ -434,11 +509,21 @@ def step(self, closure=None): # else the splited grad should be attached to the splited param grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) if len(grads) > 0: - real_working_params[group_id].append(working_param) - grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device) - splited_param.grad = grad -
grad_partition_groups.append(grad) - real_master_params[group_id].append(splited_param) + # moe hybrid zero + if self.extra_dp_pg is not None and is_moe_tensor(working_param): + real_working_params[group_id].append(working_param) + param_slice = self._world_size // self.extra_dp_pg_size + grad = grads[self.extra_dp_pg_rank * param_slice:(self.extra_dp_pg_rank + 1) * param_slice] + grad = flatten(grad).to(splited_param.dtype).to(splited_param.device) + splited_param.grad = grad + grad_partition_groups.append(grad) + real_master_params[group_id].append(splited_param) + else: + real_working_params[group_id].append(working_param) + grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device) + splited_param.grad = grad + grad_partition_groups.append(grad) + real_master_params[group_id].append(splited_param) # compute norm working_grads = self._grad_store.get_working_grads_by_group_id(group_id) @@ -473,10 +558,17 @@ def step(self, closure=None): master_working_param = self.optim.param_groups[group_id]['params'] for idx, splited_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] - all_splited_param = [ - torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size) - ] - dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg) + if self.extra_dp_pg is not None and is_moe_tensor(working_param): + all_splited_param = [ + torch.zeros(splited_param.shape, device="cuda", dtype=dtype) + for _ in range(self.extra_dp_pg_size) + ] + dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.extra_dp_pg) + else: + all_splited_param = [ + torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size) + ] + dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg) working_param.data.copy_(flatten(all_splited_param)[:working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id] @@ -568,10 +660,16 @@ def state_dict(self) -> Dict: for k, v in state.items(): if isinstance(v, torch.Tensor) and k != 'step': working_param = self._param_store.master_to_working_param[id(param)] - gather_tensor = [ - torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) - ] - dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) + if self.extra_dp_pg is not None and is_moe_tensor(v): + gather_tensor = [ + torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self.extra_dp_pg_size) + ] + dist.all_gather(gather_tensor, v.cuda(), group=self.extra_dp_pg) + else: + gather_tensor = [ + torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) + ] + dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( working_param).cpu() zero_state[param][k] = param_state @@ -595,8 +693,12 @@ def load_state_dict(self, state_dict: Dict): v = v.flatten() if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) - v_list = v.split(v.numel() // self._world_size) - zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone() + if self.extra_dp_pg is not None and is_moe_tensor(v): + v_list = v.split(v.numel() // self.extra_dp_pg_size) + zero_state_dict['state'][param_idx][k] = v_list[self.extra_dp_pg_rank].detach().clone() + else: + v_list = v.split(v.numel() // self._world_size) + 
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach().clone() self.optim.load_state_dict(zero_state_dict) @@ -627,8 +729,16 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i for k, v in states.items(): if isinstance(v, torch.Tensor) and k != 'step': - state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)] - dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) + if self.extra_dp_pg is not None and is_moe_tensor(v): + state_tensor = [ + torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self.extra_dp_pg_size) + ] + dist.all_gather(state_tensor, v.cuda(), group=self.extra_dp_pg) + else: + state_tensor = [ + torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) + ] + dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as( working_param).cpu() current_block_size += state_tensor.numel() @@ -658,4 +768,7 @@ def update_master_params(self, model: nn.Module) -> None: working_param = p.data.view(-1) if padding_size > 0: working_param = torch.nn.functional.pad(working_param, [0, padding_size]) - master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) + if self.extra_dp_pg is not None and is_moe_tensor(p): + master_param.copy_(working_param.chunk(self.extra_dp_pg_size)[self.extra_dp_pg_rank]) + else: + master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 5ff0843caaea..830ff9df0ec6 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -41,7 +41,7 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster): class RandomDataset(Dataset): - def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384): self.num_samples = num_samples self.max_length = max_length self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) @@ -86,7 +86,6 @@ def parse_args(): type=str, default="hybrid", help="parallel plugin", - choices=["zero2", "zero2_ep", "hybrid", "zero2_tp"], ) # hybrid plugin parser.add_argument("--pp_size", type=int, default=2, help="pp size") @@ -94,6 +93,7 @@ def parse_args(): parser.add_argument("--ep_size", type=int, default=2, help="ep size") parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") + parser.add_argument("--extra_dp_size", type=int, default=1) # kernel parser.add_argument( "--use_kernel", @@ -116,63 +116,73 @@ def main(): # Set plugin booster_kwargs = {} - if args.plugin == "zero2": + hybrid_dict = {"tp_size": 1, "custom_policy": OpenMoeForCausalLMPolicy(), "enable_fused_normalization": args.use_kernel, "enable_jit_fused": args.use_kernel} + mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel} + if args.plugin == "zero": dp_size = dist.get_world_size() plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) MOE_MANAGER.setup( - seed=42, parallel=None, - use_kernel_optim=args.use_kernel, + **mgr_dict, ) - elif args.plugin == "zero2_ep": + elif args.plugin == "ep": dp_size = dist.get_world_size() 
plugin = MoeHybridParallelPlugin( - tp_size=1, pp_size=1, zero_stage=2, - custom_policy=OpenMoeForCausalLMPolicy(), - enable_fused_normalization=args.use_kernel, - enable_jit_fused=args.use_kernel, + **hybrid_dict, ) MOE_MANAGER.setup( - seed=42, parallel="EP", - use_kernel_optim=args.use_kernel, + **mgr_dict, ) - elif args.plugin == "zero2_tp": + elif args.plugin == "ep_zero": dp_size = dist.get_world_size() + use_ep_inside = False plugin = MoeHybridParallelPlugin( - tp_size=1, pp_size=1, - zero_stage=2, - custom_policy=OpenMoeForCausalLMPolicy(), - enable_fused_normalization=args.use_kernel, - enable_jit_fused=args.use_kernel, + zero_stage=1, + extra_dp_size=args.extra_dp_size, + use_ep_inside=use_ep_inside, + **hybrid_dict, ) MOE_MANAGER.setup( - seed=42, - parallel="TP", - use_kernel_optim=args.use_kernel, + parallel="EP", + max_ep_size=dp_size // args.extra_dp_size, + use_ep_inside=use_ep_inside, + **mgr_dict, + ) + elif args.plugin == "zero_ep": + dp_size = dist.get_world_size() + use_ep_inside = True + plugin = MoeHybridParallelPlugin( + pp_size=1, + zero_stage=1, + extra_dp_size=args.extra_dp_size, + use_ep_inside=use_ep_inside, + **hybrid_dict, + ) + MOE_MANAGER.setup( + parallel="EP", + max_ep_size=dp_size // args.extra_dp_size, + use_ep_inside=use_ep_inside, + **mgr_dict, ) elif args.plugin == "hybrid": dp_size = dist.get_world_size() // args.pp_size plugin = MoeHybridParallelPlugin( - tp_size=1, pp_size=args.pp_size, zero_stage=args.zero_stage, microbatch_size=args.microbatch_size, - custom_policy=OpenMoeForCausalLMPolicy(), - enable_fused_normalization=args.use_kernel, - enable_jit_fused=args.use_kernel, + **hybrid_dict, ) MOE_MANAGER.setup( - seed=42, parallel="EP", mode="fixed", fixed_dp_size=args.dp_size, fixed_ep_size=args.ep_size, fixed_pp_size=args.pp_size, - use_kernel_optim=args.use_kernel, + **mgr_dict, ) else: raise ValueError(f"Invalid plugin {args.plugin}") @@ -219,7 +229,7 @@ def main(): coordinator.print_on_master(f"Finish init booster") # Start finetuning - coordinator.print_on_master(f"Start finetuning") + coordinator.print_on_master(f"Start training") model.train() train_dataloader_iter = iter(dataloader) total_len = len(train_dataloader_iter) - 1 diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh index 5db65a216461..ec4490faa55d 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -2,7 +2,7 @@ set -xue -NUM_GPU=8 +NUM_GPU=4 MODEL="8b" SEQ_LENGTH=2048 WARMUP=8 @@ -16,7 +16,7 @@ else export PYTHONPATH=$example_dir:$PYTHONPATH fi -# zero2 +# zero torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ @@ -24,10 +24,10 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --plugin zero2 \ + --plugin zero \ --use_kernel -# zero2_tp +# ep torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ @@ -35,10 +35,10 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --plugin zero2_tp \ + --plugin ep \ --use_kernel -# zero2_ep +# ep_zero torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ @@ -46,14 +46,27 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - 
--plugin zero2_ep \ - --use_kernel + --plugin ep_zero \ + --use_kernel \ + --extra_dp_size 2 + +# zero_ep +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 12 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin zero_ep \ + --use_kernel \ + --extra_dp_size 2 # hybrid torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size 512 \ + --batch_size 128 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py index 1b69c8d4abeb..0edf102d640c 100644 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.py +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py @@ -20,9 +20,8 @@ class RandomDataset(Dataset): - def __init__( - self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000 - ): + + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): self.num_samples = num_samples self.max_length = max_length self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length)) @@ -52,9 +51,7 @@ def fsdp_main(rank, world_size, args): max_length=args.seq_length, num_samples=args.batch_size * (args.warmup + args.active) * dp_size, ) - sampler = DistributedSampler( - dataset, rank=rank, num_replicas=world_size, shuffle=False - ) + sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False) train_kwargs = {"batch_size": args.batch_size, "sampler": sampler} train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs) torch.cuda.set_device(rank) @@ -64,7 +61,9 @@ def fsdp_main(rank, world_size, args): setattr(config, "router_z_loss_factor", 0.1) setattr(config, "label_smoothing", 0.1) setattr(config, "z_loss_factor", 0.1) + torch.set_default_dtype(torch.float16) model = OpenMoeForCausalLM(config) + torch.set_default_dtype(torch.float32) auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls={ @@ -114,9 +113,7 @@ def fsdp_main(rank, world_size, args): performance_evaluator.on_fit_end() if dist.get_rank() == 0: - print( - f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB" - ) + print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh index 41ffcd882a3b..e1eb2a9c6053 100755 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.sh +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh @@ -6,8 +6,8 @@ NUM_GPU=8 MODEL="8b" BATCH_SIZE=1 SEQ_LENGTH=2048 -WARMUP=5 -ACTIVE=5 +WARMUP=6 +ACTIVE=3 # HACK: make model importable example_dir=$(dirname $(realpath $(dirname $0))) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index f8c79320fa57..357c0f22a783 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -360,19 +360,17 @@ def __init__(self, config: LlamaConfig, moe: bool): self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: - self.mlp = SparseMLP( - num_experts=config.num_experts, - 
top_k=config.topk, - capacity_factor_train=config.capacity_factor_train, - capacity_factor_eval=config.capacity_factor_eval, - min_capacity=config.min_capacity, - noisy_policy=config.noisy_policy, - drop_tks=config.drop_tks, - expert_parallel=MOE_MANAGER.get_parallel() if MOE_MANAGER.is_initialized else config.expert_parallel, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - activation=config.hidden_act, - gated=config.gated) + self.mlp = SparseMLP(num_experts=config.num_experts, + top_k=config.topk, + capacity_factor_train=config.capacity_factor_train, + capacity_factor_eval=config.capacity_factor_eval, + min_capacity=config.min_capacity, + noisy_policy=config.noisy_policy, + drop_tks=config.drop_tks, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + activation=config.hidden_act, + gated=config.gated) self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = OpenMoeMLP(config) else: diff --git a/pytest.ini b/pytest.ini index 38ad7d76de50..598e0a74e71c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,4 +2,4 @@ markers = dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs) largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs) -addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy +addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_fx --ignore=tests/test_legacy diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 53266beb1877..934061ae4417 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -14,16 +14,13 @@ class MoeModel(nn.Module): - def __init__(self, checkpoint: bool = False, expert_parallel: str = "EP"): + def __init__(self, checkpoint: bool = False): class TestSubModule(CheckpointModule): def __init__(self): super().__init__(checkpoint) - self.moe = SparseMLP(num_experts=8, - expert_parallel=expert_parallel, - hidden_size=16, - intermediate_size=32) + self.moe = SparseMLP(num_experts=8, hidden_size=16, intermediate_size=32) self.proj = nn.Linear(16, 4) def _forward(self, x): @@ -127,7 +124,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_ """Sync the parameters of tp model from ep model Args: - tp_model (MoeModule) + local_model (MoeModule) ep_model (MoeModule) """ for (local_name, local_param), (ep_name, ep_param) in zip(local_model.named_parameters(), diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index b9f36b0bcc45..13e142aadd7a 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -18,7 +18,7 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(42) # MOE initialization + MOE_MANAGER.setup(42, parallel="EP") # MOE initialization num_experts_list = [1, 2, 4] layer_list = [] for num_experts in num_experts_list: @@ -26,7 +26,6 @@ def run_test(rank, world_size, port): intermediate_size=DIM * 4, num_experts=num_experts, top_k=1, - expert_parallel="EP", noisy_policy="Jitter") layer_list.append(moe_layer) diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 0074a698fd96..db40110d8d9e 100644 --- 
a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -23,7 +23,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') local_rank = dist.get_rank() - MOE_MANAGER.setup(42) # MOE environment initialization + MOE_MANAGER.setup(42, parallel="EP") # MOE environment initialization MOE_MANAGER.reset_loss() torch.manual_seed(rs + local_rank) # set each process has different random seed @@ -34,7 +34,6 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f intermediate_size=hidden_size * 2, num_experts=NUM_EXPERTS, top_k=topk, - expert_parallel="EP", capacity_factor_train=1.0) layer = layer.to(get_current_device()) if data_type == torch.float16: diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 489f5ebdacfc..09af499185db 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -1,5 +1,7 @@ +import importlib import os import shutil +import sys import pytest import torch @@ -11,8 +13,12 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.moe.manager import MOE_MANAGER from colossalai.testing import rerun_if_address_is_in_use, spawn -from examples.language.openmoe.model.modeling_openmoe import OpenMoeForCausalLM -from examples.language.openmoe.model.openmoe_policy import OpenMoeForCausalLMPolicy + +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "examples/language/openmoe")) + +# TODO: better way to import them +OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM +OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy def get_config(): diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 72b639c8b43a..2bbf739ebbd4 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -16,10 +16,11 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(42) # MOE initialization - - ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM) - tp_model = SparseMLP(num_experts=4, expert_parallel="TP", hidden_size=DIM, intermediate_size=DIM) + MOE_MANAGER.setup(42, parallel="EP") # MOE initialization + ep_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(42, parallel="TP") + tp_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) ep_model = ep_model.to(get_current_device()) tp_model = tp_model.to(get_current_device()) diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index f5d54ba290aa..e111ea6bb18d 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -59,7 +59,7 @@ def run_moe_init(expert_cls): def _run_test(rank, world_size, port, expert_cls): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(seed=42) + MOE_MANAGER.setup(seed=42, parallel="EP") run_moe_init(expert_cls) diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py new file mode 100644 index 000000000000..a2b8efb0e2dc --- 
/dev/null +++ b/tests/test_moe/test_moe_hybrid_zero.py @@ -0,0 +1,89 @@ +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel +from colossalai.moe.manager import MOE_MANAGER +from colossalai.tensor.moe_tensor.api import is_moe_tensor +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_moe.moe_utils import MoeModel + + +def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False): + model.train() + with torch.cuda.amp.autocast(enabled=enable_autocast): + if criterion: + y = model(data) + loss = criterion(y, label) + else: + loss = model(data, label) + loss = loss.float() + + if isinstance(model, LowLevelZeroModel): + optimizer.backward(loss / 2) + else: + loss.backward() + return y + + +def run_zero_optim_test(local_rank, world_size, stage=1): + criterion = torch.nn.CrossEntropyLoss() + data = torch.randn(16, 4).cuda() + label = torch.randint(0, 4, (16,)).cuda() + + MOE_MANAGER.setup(seed=42, parallel=None) + torch_model = MoeModel(checkpoint=True) + torch_optimizer = torch.optim.Adam(torch_model.parameters()) + torch_model = torch_model.cuda() + + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, max_ep_size=2, use_ep_inside=False, parallel="EP") + zero_model = MoeModel(checkpoint=True) + extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group + ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) + ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size + for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()): + if is_moe_tensor(zero_param): + num_expert = torch_param.data.shape[0] + zero_param.data.copy_(torch_param.data[ep_rank * (num_expert // ep_size):(ep_rank + 1) * + (num_expert // ep_size)].detach().clone()) + else: + zero_param.data.copy_(torch_param.data.detach().clone()) + zero_optimizer = torch.optim.Adam(zero_model.parameters()) + plugin = LowLevelZeroPlugin(stage=stage, precision="fp32") + plugin.zero_optim_kwargs["extra_dp_process_group"] = extra_dp_group + booster = Booster(plugin=plugin) + zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer) + + run_fwd_bwd(torch_model, data, label, criterion, None) + torch_optimizer.step() + run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer) + zero_optimizer.step() + + for (torch_name, torch_param), (zero_name, zero_param) in zip(torch_model.named_parameters(), + zero_model.named_parameters()): + if is_moe_tensor(zero_param): + num_expert = torch_param.data.shape[0] + torch_param.data = torch_param.data[ep_rank * (num_expert // ep_size):(ep_rank + 1) * + (num_expert // ep_size)] + assert torch.allclose(torch_param.data, zero_param.data, + atol=1e-4), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}" + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_zero_optim_test(rank, world_size, stage=1) + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_zero_optim(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_moe_zero_optim(world_size=4) diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py index 09cc0cc6a4ef..1211a0d2d7f1 100644 --- 
a/tests/test_moe/test_moe_local.py +++ b/tests/test_moe/test_moe_local.py @@ -16,10 +16,11 @@ def run_test(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(42) # MOE initialization - - ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM) - local_model = SparseMLP(num_experts=4, expert_parallel=None, hidden_size=DIM, intermediate_size=DIM) + MOE_MANAGER.setup(42, parallel=None) + local_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) + MOE_MANAGER.__init__() + MOE_MANAGER.setup(42, parallel="EP") # MOE initialization + ep_model = SparseMLP(num_experts=4, hidden_size=DIM, intermediate_size=DIM) ep_model = ep_model.to(get_current_device()) local_model = local_model.to(get_current_device()) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index 2b2afa4623b5..499d65f0072a 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -86,7 +86,7 @@ def run_zero_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(seed=42) + MOE_MANAGER.setup(seed=42, parallel="EP") seed_all(42 + rank) run_zero_test(rank, world_size, stage=1) run_zero_test(rank, world_size, stage=2) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 38a5cfbfd66e..8f4d89f17330 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -75,7 +75,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1): def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(seed=42) + MOE_MANAGER.setup(seed=42, parallel="EP") run_zero_optim_test(rank, world_size, stage=1) run_zero_optim_test(rank, world_size, stage=2)
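As a reference for the ZeRO bookkeeping above, a toy, standalone sketch of the sharding arithmetic (illustration only, not library code): moe master params are split across the extra dp group in _create_master_param_current_rank, while their grads are kept in world-size shards by the grad store and re-merged into extra-dp slices in the moe branch of step():

    import torch

    world_size, extra_dp_size = 4, 2
    param = torch.arange(16, dtype=torch.float32)    # a flattened, padded moe working param

    # master params are split over the extra dp group (extra_dp_size shards)
    master_shards = param.split(param.numel() // extra_dp_size)

    # grads are bookkept in world_size shards by the grad store
    grad_shards = list(param.split(param.numel() // world_size))

    # in step(), each extra-dp rank merges world_size // extra_dp_size grad shards
    param_slice = world_size // extra_dp_size
    extra_dp_rank = 1
    merged = torch.cat(grad_shards[extra_dp_rank * param_slice:(extra_dp_rank + 1) * param_slice])

    # the merged grad slice lines up exactly with that rank's master shard
    assert torch.equal(merged, master_shards[extra_dp_rank])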