diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py
index 50178b5fa850..08ef4e35fe2d 100644
--- a/colossalai/context/__init__.py
+++ b/colossalai/context/__init__.py
@@ -1,6 +1,6 @@
 from .config import Config, ConfigException
+from .moe_context import MOE_CONTEXT
 from .parallel_context import ParallelContext
 from .parallel_mode import ParallelMode
-from .moe_context import MOE_CONTEXT
 from .process_group_initializer import *
 from .random import *
diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py
index 652b9c2382f4..b41f4072a405 100644
--- a/colossalai/context/moe_context.py
+++ b/colossalai/context/moe_context.py
@@ -3,9 +3,29 @@
 import torch
 import torch.distributed as dist
 
+from colossalai.context.parallel_mode import ParallelMode
 from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor.moe_tensor.api import get_moe_info
-from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
+from colossalai.tensor import ProcessGroup
+
+
+def _check_sanity():
+    from colossalai.core import global_context as gpc
+    if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
+        raise NotImplementedError("MoE is not compatible with tensor or "
+                                  "pipeline parallelism at present.")
+
+
+class MoeParallelInfo:
+    """MoE parallelism information, storing parallel sizes and groups.
+    """
+
+    def __init__(self, ep_size: int, dp_size: int):
+        _check_sanity()
+        self.ep_size = ep_size
+        self.dp_size = dp_size
+        self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
+        self.ep_group = self.pg.tp_process_group()
+        self.dp_group = self.pg.dp_process_group()
 
 
 class MoeContext(metaclass=SingletonMeta):
@@ -14,15 +34,13 @@ class MoeContext(metaclass=SingletonMeta):
     """
 
    def __init__(self):
-        self.world_size = None
+        self.world_size = 1
         # Users may want to set maximum expert parallel size smaller than the world size
         # since very low bandwidth across nodes may constrain the performance of MoE
         # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
-        self.max_ep_size = None
-        self.min_dp_size = None
-        self.router_aux_loss = []
-        self.router_z_loss = []
-        self.parallel = None
+        self.max_ep_size = 1
+        self.min_dp_size = 1
+        self.aux_loss = None
         self.use_kernel_optim = True
 
         self.has_setup = False
@@ -36,14 +54,18 @@ def parallel_info_dict(self):
     def is_initialized(self):
         return self.has_setup
 
-    def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: bool = None):
+    def setup(self, seed: int, use_kernel_optim: bool = True):
         assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
+        _check_sanity()
         assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
 
         self.world_size = dist.get_world_size()
-        self.max_ep_size = min(max_ep_size, dist.get_world_size())
+
+        from colossalai.core import global_context as gpc
+        self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
+        assert self.world_size % self.max_ep_size == 0, \
+            "Maximum expert parallel size must be a factor of the number of GPUs"
         self.min_dp_size = self.world_size // self.max_ep_size
-        self.parallel = parallel
 
         # Enabling kernel optimization may raise error in some cases
         # Users can close kernel optimization manually
@@ -53,7 +75,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8,
         moe_set_seed(seed)
         self.has_setup = True
 
-    def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
+    def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
         """Calculate the Data Parallel Group and Expert Parallel Group.
 
         Parameters
@@ -82,15 +104,12 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara
         ep_size = self.max_ep_size // dp_size
 
         # Calculate the number of experts for each GPU
-        if use_tp:
-            num_local_experts = num_experts
-        else:
-            num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
+        num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
 
         # Don't forget to multiply minimum data parallel size
         dp_size *= self.min_dp_size
         if not (ep_size in self.parallel_info_dict):
-            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size)
+            self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)
 
         return num_local_experts, self.parallel_info_dict[ep_size]
 
@@ -98,18 +117,13 @@ def set_kernel_not_use(self):
         self.use_kernel_optim = False
 
     def reset_loss(self):
-        self.router_aux_loss, self.router_z_loss = [], []
+        self.aux_loss = 0
 
-    def add_loss(self, aux_loss: float = 0., z_loss: float = 0.):
-        self.router_aux_loss.append(aux_loss)
-        self.router_z_loss.append(z_loss)
+    def add_loss(self, loss):
+        self.aux_loss += loss
 
     def get_loss(self):
-        cur_loss = self.router_aux_loss, self.router_z_loss
-        return cur_loss
-
-    def get_parallel(self):
-        return self.parallel
+        return self.aux_loss
 
 
 MOE_CONTEXT = MoeContext()
diff --git a/colossalai/context/random/__init__.py b/colossalai/context/random/__init__.py
index d64b993257c1..e2314f859d3f 100644
--- a/colossalai/context/random/__init__.py
+++ b/colossalai/context/random/__init__.py
@@ -3,7 +3,6 @@
     get_current_mode,
     get_seeds,
     get_states,
-    moe_set_seed,
     reset_seeds,
     seed,
     set_mode,
@@ -14,5 +13,5 @@
 
 __all__ = [
     'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states',
-    'sync_states', 'moe_set_seed', 'reset_seeds'
+    'sync_states', 'reset_seeds'
 ]
diff --git a/colossalai/context/random/_helper.py b/colossalai/context/random/_helper.py
index 973c4d9faa32..811d40f660ff 100644
--- a/colossalai/context/random/_helper.py
+++ b/colossalai/context/random/_helper.py
@@ -7,8 +7,8 @@
 import torch.cuda
 from torch import Tensor
 
-from .seed_manager import SeedManager
 from ..parallel_mode import ParallelMode
+from .seed_manager import SeedManager
 
 _SEED_MANAGER = SeedManager()
 
@@ -159,14 +159,5 @@ def wrapper(*args, **kwargs):
         return wrapper
 
 
-def moe_set_seed(seed):
-    if torch.cuda.is_available():
-        from colossalai.core import global_context as gpc
-        global_rank = gpc.get_global_rank()
-        diff_seed = seed + global_rank
-        add_seed(ParallelMode.TENSOR, diff_seed, True)
-        print(f"moe seed condition: {global_rank} with tensor seed {diff_seed}", flush=True)
-
-
 def reset_seeds():
     _SEED_MANAGER.reset()
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index dc0df0517508..c0aedffc26fe 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -19,7 +19,6 @@
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
-from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
 from colossalai.engine.gradient_accumulation import accumulate_gradient
@@ -32,7 +31,6 @@
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
-from colossalai.utils.moe import sync_moe_model_param
 from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2
 from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
 
@@ -306,8 +304,6 @@ def initialize(model: nn.Module,
     if not use_zero:
         if is_using_sequence():
             sync_model_param(model, ParallelMode.SEQUENCE_DP)
-        elif MOE_CONTEXT.is_initialized:
-            sync_moe_model_param(model)
         elif is_using_ddp():
             sync_model_param(model, ParallelMode.DATA)
     else:
@@ -359,13 +355,6 @@ def initialize(model: nn.Module,
                 "Training with zero is detected, ZeROGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
-    elif is_using_ddp() and MOE_CONTEXT.is_initialized:
-        gradient_handler_cfg = [dict(type='MoeGradientHandler')]
-        if verbose:
-            logger.info(
-                "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
-                "added even though not specified in the configuration",
-                ranks=[0])
     elif is_using_sequence():
         model = DDP(model,
                     process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
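The `MoeParallelInfo` added in `moe_context.py` above lays the ep and dp groups over the world as a 2D grid (`ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)`, with ep on the tp-like inner axis). A minimal standalone sketch of that layout, assuming contiguous ep ranks, which is an illustrative convention rather than a guarantee of the `ProcessGroup` API (pure Python, no distributed backend needed):

    # Hypothetical illustration of the dp x ep rank grid behind MoeParallelInfo.
    # Assumes ep ranks are contiguous (inner axis), matching tp_degree=ep_size above.
    def moe_rank_grid(world_size: int, ep_size: int):
        dp_size = world_size // ep_size
        ep_groups = [list(range(d * ep_size, (d + 1) * ep_size)) for d in range(dp_size)]
        dp_groups = [list(range(e, world_size, ep_size)) for e in range(ep_size)]
        return ep_groups, dp_groups

    ep_groups, dp_groups = moe_rank_grid(world_size=8, ep_size=4)
    print(ep_groups)    # [[0, 1, 2, 3], [4, 5, 6, 7]] -> experts sharded inside each group
    print(dp_groups)    # [[0, 4], [1, 5], [2, 6], [3, 7]] -> replicas synced across groups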
diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py
new file mode 100644
index 000000000000..492cdaf13d1d
--- /dev/null
+++ b/colossalai/moe/__init__.py
@@ -0,0 +1,10 @@
+from .checkpoint import MoeCheckpintIO
+from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts
+from .layers import SparseMLP
+from .routers import MoeRouter, Top1Router, Top2Router
+from .utils import NormalNoiseGenerator, UniformNoiseGenerator
+
+__all__ = [
+    'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'NormalNoiseGenerator', 'UniformNoiseGenerator',
+    'SparseMLP', 'MoeRouter', 'MoeCheckpintIO', 'build_ffn_experts'
+]
diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/moe/_operation.py
similarity index 100%
rename from colossalai/nn/layer/moe/_operation.py
rename to colossalai/moe/_operation.py
diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/moe/checkpoint.py
similarity index 100%
rename from colossalai/nn/layer/moe/checkpoint.py
rename to colossalai/moe/checkpoint.py
diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/moe/experts.py
similarity index 78%
rename from colossalai/nn/layer/moe/experts.py
rename to colossalai/moe/experts.py
index fd93bed97992..da4fe58977e8 100644
--- a/colossalai/nn/layer/moe/experts.py
+++ b/colossalai/moe/experts.py
@@ -1,15 +1,14 @@
 import math
-from copy import deepcopy
+from contextlib import nullcontext
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 
-from colossalai.context import ParallelMode, seed
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.nn.layer.moe._operation import MoeInGradScaler, MoeOutGradScaler
-from colossalai.nn.layer.moe.utils import get_activation
-from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size, set_moe_tensor_info
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import get_activation
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info
 
 
 class BaseMLPExperts(nn.Module):
@@ -35,13 +34,13 @@ def __init__(
 
         # get expert parallel info
         if expert_parallel is not None:
-            self.num_local_experts, self.moe_info = MOE_CONTEXT.get_info(
+            self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
                 num_experts, use_tp=True if expert_parallel == "TP" else False)
             # get settings for different parallel
             if expert_parallel == "TP":
-                assert intermediate_size % MOE_CONTEXT.max_ep_size == 0, \
-                    "intermediate_size should be divide by maximum expert parallel size"
-                intermediate_size = intermediate_size // MOE_CONTEXT.max_ep_size
+                assert intermediate_size % MOE_MANAGER.max_ep_size == 0, \
+                    "intermediate_size should be divisible by the maximum expert parallel size"
+                intermediate_size = intermediate_size // MOE_MANAGER.max_ep_size
                 num_experts = self.num_total_experts
             else:
                 num_experts = self.num_local_experts
@@ -57,14 +56,18 @@ def __init__(
             self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
         self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
 
+        # expert params should be different
         if expert_parallel is not None:
-            with seed(ParallelMode.TENSOR):
-                if gated:
-                    nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
-                    nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
-                else:
-                    nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
-                nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))
+            seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
+        else:
+            seed_ctx = nullcontext()
+        with seed_ctx:
+            if gated:
+                nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
+                nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
+            else:
+                nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
+            nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))
 
         self.act = get_activation(activation)
         self.drop = nn.Dropout(p=drop_rate)
@@ -88,10 +91,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:    # inputs [g, e, c, h]
         else:
             x = torch.bmm(x, self.wi)
         x = self.act(x)
-
-        if self.expert_parallel is not None:
-            with seed(ParallelMode.TENSOR):
-                x = self.drop(x)
+        x = self.drop(x)
 
         x = torch.bmm(x, self.wo)
         x = x.reshape(inshape)
@@ -143,7 +143,7 @@ def get_expert_class(name: str) -> BaseMLPExperts:
 
 
 def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-    mep_size = MOE_CONTEXT.max_ep_size
+    mep_size = MOE_MANAGER.max_ep_size
    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
         return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
     elif d_ff % mep_size == 0:
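The `Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)` change above replaces the old `seed(ParallelMode.TENSOR)` context: expert weights must be initialized from a rank-dependent seed so different ep ranks hold different experts, without disturbing the RNG state of the surrounding code. A hedged pure-torch sketch of the same fork-seed-restore idea (the `42 + rank` offset mirrors how `MoeManager.setup` below derives `self.seed`):

    import torch
    from contextlib import contextmanager

    # Sketch of the per-rank seeding pattern: fork the RNG state, seed it with a
    # rank-dependent value, initialize, then restore the global state.
    @contextmanager
    def fork_rng_with_seed(seed: int):
        cpu_state = torch.get_rng_state()    # saved so surrounding code is unaffected
        torch.manual_seed(seed)
        try:
            yield
        finally:
            torch.set_rng_state(cpu_state)

    rank = 0    # in the real code this offset comes from dist.get_rank()
    w = torch.empty(4, 8)
    with fork_rng_with_seed(42 + rank):
        torch.nn.init.trunc_normal_(w, std=0.02)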
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/moe/layers.py
similarity index 76%
rename from colossalai/nn/layer/moe/layers.py
rename to colossalai/moe/layers.py
index 3f65bcde8b29..ace81b543273 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -5,18 +5,11 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.nn.layer.moe._operation import (
-    COL_MOE_KERNEL_FLAG,
-    AllGather,
-    AllToAll,
-    MoeCombine,
-    MoeDispatch,
-    ReduceScatter,
-)
-from colossalai.nn.layer.moe.experts import BaseMLPExperts, get_expert_class
-from colossalai.nn.layer.moe.routers import MoeRouter, get_router_cls
-from colossalai.nn.layer.moe.utils import get_noise_generator
+from colossalai.moe._operation import COL_MOE_KERNEL_FLAG, AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
+from colossalai.moe.experts import BaseMLPExperts, get_expert_class
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.routers import MoeRouter, get_router_cls
+from colossalai.moe.utils import get_noise_generator
 from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_size
 
@@ -65,7 +58,7 @@ def __init__(self,
         super().__init__()
         self.hidden_size = hidden_size
         self.num_experts = num_experts
-        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
+        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_MANAGER.use_kernel_optim else False
         self.expert_parallel = expert_parallel
         assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}"
 
@@ -156,45 +149,3 @@ def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
         expert_out = self.experts(expert_in)
         expert_out = ReduceScatter.apply(expert_out, self.ep_group)
         return expert_out
-
-
-class MoeModule(SparseMLP):
-    """
-    For other dependency
-    """
-
-    def __init__(self,
-                 num_experts: int,
-                 top_k: int = 1,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_policy: Optional[str] = None,
-                 drop_tks: bool = True,
-                 expert_parallel: str = "EP",
-                 hidden_size: int = 2048,
-                 intermediate_size: int = 2048,
-                 activation: str = None):
-        super().__init__(num_experts, top_k, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_policy,
-                         drop_tks, expert_parallel, hidden_size, intermediate_size, activation)
-
-
-class MoeLayer(SparseMLP):
-    """
-    For other dependency
-    """
-
-    def __init__(self,
-                 num_experts: int,
-                 top_k: int = 1,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_policy: Optional[str] = None,
-                 drop_tks: bool = True,
-                 expert_parallel: str = "EP",
-                 hidden_size: int = 2048,
-                 intermediate_size: int = 2048,
-                 activation: str = None):
-        super().__init__(num_experts, top_k, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_policy,
-                         drop_tks, expert_parallel, hidden_size, intermediate_size, activation)
diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/moe/loss.py
similarity index 90%
rename from colossalai/nn/loss/loss_moe.py
rename to colossalai/moe/loss.py
index a8b18a3e37ee..75624510b452 100644
--- a/colossalai/nn/loss/loss_moe.py
+++ b/colossalai/moe/loss.py
@@ -1,80 +1,78 @@
-import torch.nn as nn
-from colossalai.registry import LOSSES
-from torch.nn.modules.loss import _Loss
-from colossalai.context.moe_context import MOE_CONTEXT
-
-
-@LOSSES.register_module
-class MoeCrossEntropyLoss(_Loss):
-    r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
-
-    Args:
-        input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
-        target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-        aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01.
-
-    The ``args`` and ``kwargs`` should include parameters below:
-    ::
-
-        weight (Tensor, optional)
-        size_average (bool, optional)
-        ignore_index (int, optional)
-        reduce (bool, optional)
-        reduction (str, optional)
-        label_smoothing (float, optional)
-
-    More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
-    `Cross_entropy `_.
-    """
-
-    def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
-        super().__init__()
-        self.loss = nn.CrossEntropyLoss(*args, **kwargs)
-        self.aux_weight = aux_weight
-
-    def forward(self, *args):
-        """
-        The ``args`` should at least include parameters below:
-        ::
-
-            input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
-            target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-
-        More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
-        `Cross_entropy `_.
-        """
-        main_loss = self.loss(*args)
-        aux_loss = MOE_CONTEXT.get_loss()
-        return main_loss + self.aux_weight * aux_loss
-
-
-@LOSSES.register_module
-class MoeLoss(_Loss):
-    """A wrapper class for any loss module to add with auxiliary loss.
-
-    Args:
-        aux_weight (float): Weight of auxiliary loss in total loss.
-        loss_fn (``Callable``): Loss function.
-        args (list): Args in loss function.
-        kwargs (dict): Kwargs in loss function
-    """
-
-    def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
-        super().__init__()
-        self.loss_fn = loss_fn(*args, **kwargs)
-        self.aux_weight = aux_weight
-
-    def forward(self, *args, **kwargs):
-        """
-        The ``args`` and ``kwargs`` should at least include parameters below:
-        ::
-
-            input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
-            target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
-
-        Note:
-            The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
-        """
-        main_loss = self.loss_fn(*args, **kwargs)
-        aux_loss = MOE_CONTEXT.get_loss()
-        return main_loss + self.aux_weight * aux_loss
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
+
+from colossalai.moe.manager import MOE_MANAGER
+
+
+class MoeCrossEntropyLoss(_Loss):
+    r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
+
+    Args:
+        input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+        target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+        aux_weight (float, optional): Weight of auxiliary loss in total loss. Defaults to 0.01.
+
+    The ``args`` and ``kwargs`` should include parameters below:
+    ::
+
+        weight (Tensor, optional)
+        size_average (bool, optional)
+        ignore_index (int, optional)
+        reduce (bool, optional)
+        reduction (str, optional)
+        label_smoothing (float, optional)
+
+    More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
+    `Cross_entropy `_.
+    """
+
+    def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
+        super().__init__()
+        self.loss = nn.CrossEntropyLoss(*args, **kwargs)
+        self.aux_weight = aux_weight
+
+    def forward(self, *args):
+        """
+        The ``args`` should at least include parameters below:
+        ::
+
+            input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+            target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+
+        More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
+        `Cross_entropy `_.
+        """
+        main_loss = self.loss(*args)
+        aux_loss = MOE_MANAGER.get_loss()
+        return main_loss + self.aux_weight * aux_loss
+
+
+class MoeLoss(_Loss):
+    """A wrapper class for any loss module to add with auxiliary loss.
+
+    Args:
+        aux_weight (float): Weight of auxiliary loss in total loss.
+        loss_fn (``Callable``): Loss function.
+        args (list): Args in loss function.
+        kwargs (dict): Kwargs in loss function.
+    """
+
+    def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
+        super().__init__()
+        self.loss_fn = loss_fn(*args, **kwargs)
+        self.aux_weight = aux_weight
+
+    def forward(self, *args, **kwargs):
+        """
+        The ``args`` and ``kwargs`` should at least include parameters below:
+        ::
+
+            input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
+            target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
+
+        Note:
+            The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
+        """
+        main_loss = self.loss_fn(*args, **kwargs)
+        aux_loss = MOE_MANAGER.get_loss()
+        return main_loss + self.aux_weight * aux_loss
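A small self-contained sketch of the wrapper pattern `MoeLoss` implements: a task loss plus a weighted auxiliary term. Here the auxiliary value is passed in explicitly as a stand-in; the real class pulls it from `MOE_MANAGER.get_loss()` after the routers have accumulated it:

    import torch
    import torch.nn as nn

    # Stand-in for MoeLoss; the aux term is an explicit argument for illustration.
    class AuxWrappedLoss(nn.Module):
        def __init__(self, aux_weight: float, loss_fn: nn.Module):
            super().__init__()
            self.aux_weight = aux_weight
            self.loss_fn = loss_fn

        def forward(self, logits, target, aux_loss):
            return self.loss_fn(logits, target) + self.aux_weight * aux_loss

    criterion = AuxWrappedLoss(0.01, nn.CrossEntropyLoss())
    logits, target = torch.randn(4, 10), torch.randint(0, 10, (4,))
    loss = criterion(logits, target, aux_loss=torch.tensor(0.5))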
diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py
new file mode 100644
index 000000000000..3dc27c6cb0f0
--- /dev/null
+++ b/colossalai/moe/manager.py
@@ -0,0 +1,115 @@
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor.moe_tensor.api import get_moe_info
+from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
+
+
+class MoeManager(metaclass=SingletonMeta):
+    """MoE manager. This class manages different
+    parallel groups in MoE context and MoE loss in training.
+    """
+
+    def __init__(self):
+        self.world_size = None
+        # Users may want to set maximum expert parallel size smaller than the world size
+        # since very low bandwidth across nodes may constrain the performance of MoE
+        # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
+        self.max_ep_size = None
+        self.min_dp_size = None
+        self.router_aux_loss = []
+        self.router_z_loss = []
+        self.parallel = None
+        self.seed = None
+        self.use_kernel_optim = True
+
+        self.has_setup = False
+        self._parallel_info_dict = dict()
+
+    @property
+    def parallel_info_dict(self):
+        return self._parallel_info_dict
+
+    @property
+    def is_initialized(self):
+        return self.has_setup
+
+    def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: bool = None):
+        assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
+        assert torch.cuda.is_available(), "MoE requires CUDA to be enabled first"
+
+        self.world_size = dist.get_world_size()
+        self.seed = seed + dist.get_rank()
+        self.max_ep_size = min(max_ep_size, dist.get_world_size())
+        self.min_dp_size = self.world_size // self.max_ep_size
+        self.parallel = parallel
+
+        # Enabling kernel optimization may raise error in some cases
+        # Users can close kernel optimization manually
+        self.use_kernel_optim = use_kernel_optim
+
+        self.has_setup = True
+
+    def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
+        """Calculate the Data Parallel Group and Expert Parallel Group.
+
+        Parameters
+        ----------
+        num_experts : int
+            The number of experts
+
+        Returns
+        -------
+        int, MoeParallelInfo
+            number of local experts, the MoeParallelInfo of the current ep_size
+        """
+
+        gt_flag = num_experts % self.max_ep_size == 0    # check whether num_experts is greater
+        lt_flag = self.max_ep_size % num_experts == 0    # check whether num_experts is less
+        assert gt_flag or lt_flag, "Automatic expert placement does not support an expert count that is" \
+            " not a multiple of the ep size, or vice versa."
+
+        # If the number of experts is greater than the maximum expert parallel size, a.k.a. ep_size,
+        # there are multiple experts in each GPU and each GPU has different experts
+        # So its data parallel size is 1
+        # Otherwise, there is only one expert in each GPU
+        # The data parallel size should be calculated
+        dp_size = 1 if gt_flag else self.max_ep_size // num_experts
+        ep_size = self.max_ep_size // dp_size
+
+        # Calculate the number of experts for each GPU
+        if use_tp:
+            num_local_experts = num_experts
+        else:
+            num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
+
+        # Don't forget to multiply minimum data parallel size
+        dp_size *= self.min_dp_size
+        if not (ep_size in self.parallel_info_dict):
+            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size)
+
+        return num_local_experts, self.parallel_info_dict[ep_size]
+
+    def set_kernel_not_use(self):
+        self.use_kernel_optim = False
+
+    def reset_loss(self):
+        self.router_aux_loss, self.router_z_loss = [], []
+
+    def add_loss(self, aux_loss: float = 0., z_loss: float = 0.):
+        self.router_aux_loss.append(aux_loss)
+        self.router_z_loss.append(z_loss)
+
+    def get_loss(self):
+        cur_loss = self.router_aux_loss, self.router_z_loss
+        return cur_loss
+
+    def get_parallel(self):
+        return self.parallel
+
+
+MOE_MANAGER = MoeManager()
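To make `get_info`'s sizing arithmetic concrete, a worked example with illustrative values `max_ep_size=4`, `min_dp_size=2` (pure Python, mirroring the non-TP branch above):

    # Mirrors MoeManager.get_info's sizing logic for max_ep_size=4, min_dp_size=2.
    def sizes(num_experts, max_ep_size=4, min_dp_size=2):
        gt_flag = num_experts % max_ep_size == 0
        lt_flag = max_ep_size % num_experts == 0
        assert gt_flag or lt_flag
        dp_size = 1 if gt_flag else max_ep_size // num_experts
        ep_size = max_ep_size // dp_size
        num_local_experts = 1 if lt_flag else num_experts // max_ep_size
        return num_local_experts, ep_size, dp_size * min_dp_size

    print(sizes(8))    # (2, 4, 2): 8 experts over 4 ep ranks -> 2 experts per GPU
    print(sizes(2))    # (1, 2, 4): 2 experts -> ep group of 2, dp size grows to 4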
diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/moe/routers.py
similarity index 98%
rename from colossalai/nn/layer/moe/routers.py
rename to colossalai/moe/routers.py
index 9332302a096a..dd9243421667 100644
--- a/colossalai/nn/layer/moe/routers.py
+++ b/colossalai/moe/routers.py
@@ -8,8 +8,8 @@
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup
 
-from colossalai.context import MOE_CONTEXT
-from colossalai.nn.layer.moe._operation import moe_cumsum
+from colossalai.moe._operation import moe_cumsum
+from colossalai.moe.manager import MOE_MANAGER
 from colossalai.utils import get_current_device
 
 
@@ -66,7 +66,7 @@ def set_z_loss(self, router_logits: torch.Tensor):
 
     def pop_router_loss(self) -> torch.Tensor:
         assert self._aux_loss is not None
-        MOE_CONTEXT.add_loss(self._aux_loss, self._z_loss)
+        MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
         self._aux_loss = None
         self._z_loss = None
diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/moe/utils.py
similarity index 66%
rename from colossalai/nn/layer/moe/utils.py
rename to colossalai/moe/utils.py
index 5b3542c80595..58c1665a4d63 100644
--- a/colossalai/nn/layer/moe/utils.py
+++ b/colossalai/moe/utils.py
@@ -1,10 +1,13 @@
 import contextlib
-from typing import Callable
+from typing import Callable, Dict, List
 
 import torch
+import torch.distributed as dist
+import torch.nn as nn
 import torch.nn.functional as F
 
-from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
 from colossalai.utils import get_current_device
 
 
@@ -119,3 +122,45 @@ def _skip_init(x, *args, **kwargs):
     for fn, fn_saved in zip(init_fn_list, fn_saved):
         fn = fn_saved
     return
+
+
+def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
+    """Returns a parameter dictionary, the key of which is the expert parallel
+    size of every parameter. Since the parameters in data parallelism are replicated
+    on each GPU, we set their ep_size to 1.
+
+    Args:
+        model (:class:`torch.nn.Module`): A PyTorch `nn.Module` from which we get the dict.
+    """
+    epsize_param_dict = dict()
+    for param in model.parameters():
+        if not is_moe_tensor(param):
+            ep_size = 1    # set ep_size to 1 for dp parameters
+        else:
+            ep_size = get_ep_size(param)
+        if ep_size not in epsize_param_dict:
+            epsize_param_dict[ep_size] = []
+        epsize_param_dict[ep_size].append(param)
+
+    return epsize_param_dict
+
+
+def sync_moe_model_param(model: nn.Module):
+    """Make sure model parameters are consistent in MoE parallel context.
+
+    Args:
+        model (:class:`torch.nn.Module`): A PyTorch model whose parameters are checked for consistency.
+    """
+    param_dict = get_moe_epsize_param_dict(model)
+
+    # synchronize the parameters whose dp_group is the whole world
+    if 1 in param_dict:
+        for param in param_dict[1]:
+            dist.broadcast(param, src=0)
+
+    for ep_size in param_dict:
+        # When ep_size = world_size, communication is not needed
+        if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
+            for param in param_dict[ep_size]:
+                src_rank = get_dp_group_ranks(param)[0]
+                dist.broadcast(param, src=src_rank, group=get_dp_group(param))
diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py
index b705632f8040..09c6615ea2ad 100644
--- a/colossalai/nn/layer/__init__.py
+++ b/colossalai/nn/layer/__init__.py
@@ -1,10 +1,10 @@
 from .colossalai_layer import *
+from .moe import *
 from .parallel_1d import *
 from .parallel_2d import *
 from .parallel_2p5d import *
 from .parallel_3d import *
 from .parallel_sequence import *
-from .moe import *
 from .utils import *
 from .vanilla import *
 from .wrapper import *
diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py
index 52f529814eba..5280acf8dee7 100644
--- a/colossalai/nn/layer/moe/__init__.py
+++ b/colossalai/nn/layer/moe/__init__.py
@@ -1,10 +1,12 @@
-from .checkpoint import MoeCheckpintIO
-from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts
-from .layers import MoeLayer, MoeModule, SparseMLP
-from .routers import MoeRouter, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator, UniformNoiseGenerator
-
-__all__ = [
-    'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeModule', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO', 'build_ffn_experts'
-]
+MoeModule = None
+MoeLayer = None
+build_ffn_experts = None
+EPMLPExperts = None
+TPMLPExperts = None
+Top1Router = None
+Top2Router = None
+NormalNoiseGenerator = None
+UniformNoiseGenerator = None
+SparseMLP = None
+MoeRouter = None
+MoeCheckpintIO = None
diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py
index 373e4ec9468b..3c1a16d44c3f 100644
--- a/colossalai/nn/loss/__init__.py
+++ b/colossalai/nn/loss/__init__.py
@@ -1,14 +1,14 @@
-from colossalai.global_variables import tensor_parallel_env as env
-from colossalai.nn.layer.utils import get_tensor_parallel_mode
 from torch import nn
 from torch.nn.modules.loss import *
 from torch.nn.modules.loss import _Loss
 
+from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.nn.layer.utils import get_tensor_parallel_mode
+
 from .loss_1d import VocabParallelCrossEntropyLoss1D
 from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D
 from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
 from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
-from .loss_moe import MoeCrossEntropyLoss, MoeLoss
 
 _parallel_cross_entropy = {
     '2d': CrossEntropyLoss2D,
diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py
index b9b6d338438e..442b3c0f4958 100644
--- a/colossalai/tensor/moe_tensor/api.py
+++ b/colossalai/tensor/moe_tensor/api.py
@@ -107,3 +107,29 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
         int: The data parallel rank of the given tensor.
     """
     return dist.get_rank(get_dp_group(tensor))
+
+
+def get_ep_group_ranks(tensor: torch.Tensor) -> list:
+    """
+    Get the expert parallel group ranks of the given tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be checked.
+
+    Returns:
+        list: The expert parallel group ranks of the given tensor.
+    """
+    return tensor.moe_info.ep_group_ranks
+
+
+def get_dp_group_ranks(tensor: torch.Tensor) -> list:
+    """
+    Get the data parallel group ranks of the given tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be checked.
+
+    Returns:
+        list: The data parallel group ranks of the given tensor.
+    """
+    return tensor.moe_info.dp_group_ranks
diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py
index 89f79f162b5b..ca7f163b9c24 100644
--- a/colossalai/tensor/moe_tensor/moe_info.py
+++ b/colossalai/tensor/moe_tensor/moe_info.py
@@ -12,4 +12,6 @@ def __init__(self, ep_size: int, dp_size: int):
         self.ep_size = ep_size
         self.pg = ProcessGroupMesh(self.dp_size, self.ep_size)
         self.ep_group = self.pg.get_group_along_axis(self.ep_axis)
+        self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
         self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
+        self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
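`sync_moe_model_param` in `colossalai/moe/utils.py` above uses `get_dp_group_ranks(param)[0]` as the broadcast source, i.e. the first global rank in the parameter's dp group (`torch.distributed.broadcast` takes a global rank for `src`). A sketch of that convention on a 2 (dp) x 2 (ep) mesh; the row-major layout of `ProcessGroupMesh(dp_size, ep_size)` is assumed here for illustration:

    # Mesh of 4 ranks: axis 0 = dp, axis 1 = ep; entry [d][e] is the global rank.
    mesh = [[0, 1],
            [2, 3]]
    ep_group_ranks = mesh    # rows: ranks that shard experts together
    dp_group_ranks = [[row[e] for row in mesh] for e in range(2)]    # columns: replicas
    for ranks in dp_group_ranks:
        src = ranks[0]    # dist.broadcast(param, src=src, group=...) uses the global rank
        print(f"dp group {ranks}: broadcast expert weights from global rank {src}")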
diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py
deleted file mode 100644
index 86d04c11958b..000000000000
--- a/colossalai/utils/moe.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import torch.nn as nn
-import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.context import ParallelMode
-from .common import is_using_ddp
-from typing import Dict, List
-
-
-def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
-    """Returns a parameter dictionary, the key of which is the expert parallel
-    size of every parameter. Since the parameters in data parallelism is replicated
-    in each GPU, we set their ep_size to 1.
-
-    Args:
-        model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict.
-    """
-    epsize_param_dict = dict()
-    for param in model.parameters():
-        if not hasattr(param, 'moe_info'):
-            ep_size = 1    # set ep_size to 1 for dp parameters
-        else:
-            ep_size = param.moe_info.ep_size
-        if ep_size not in epsize_param_dict:
-            epsize_param_dict[ep_size] = []
-        epsize_param_dict[ep_size].append(param)
-
-    return epsize_param_dict
-
-
-def sync_moe_model_param(model: nn.Module):
-    """Make sure model parameters are consistent in MoE parallel context.
-
-    Args:
-        model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
-    """
-    if is_using_ddp():
-
-        param_dict = get_moe_epsize_param_dict(model)
-
-        # synchronize the parameters whose dp_group is the whole world
-        if 1 in param_dict:
-            src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
-            for param in param_dict[1]:
-                dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA))
-
-        for ep_size in param_dict:
-            # When ep_size = world_size, communication is not needed
-            if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
-                src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group)
-                for param in param_dict[ep_size]:
-                    dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 1ea9d48523c3..ec7e1e8941f7 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -36,8 +36,8 @@
     replace_return_docstrings,
 )
 
-from colossalai.context import MOE_CONTEXT
-from colossalai.nn.layer.moe.layers import SparseMLP
+from colossalai.moe.layers import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
 
 logger = logging.get_logger(__name__)
 
@@ -455,7 +455,7 @@ def __init__(self, config: LlamaConfig, moe: bool):
             min_capacity=config.min_capacity,
             noisy_policy=config.noisy_policy,
             drop_tks=config.drop_tks,
-            expert_parallel=MOE_CONTEXT.get_parallel() if MOE_CONTEXT.is_initialized else config.expert_parallel,
+            expert_parallel=MOE_MANAGER.get_parallel() if MOE_MANAGER.is_initialized else config.expert_parallel,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             activation=config.hidden_act,
@@ -891,7 +891,7 @@ def forward(
         "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
         ```"""
         # reset moe loss
-        MOE_CONTEXT.reset_loss()
+        MOE_MANAGER.reset_loss()
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (output_hidden_states
@@ -1014,7 +1014,7 @@ def _reorder_cache(past_key_values, beam_idx):
         return reordered_past
 
     def _calculate_router_loss(self):
-        aux_loss, z_loss = MOE_CONTEXT.get_loss()
+        aux_loss, z_loss = MOE_MANAGER.get_loss()
         assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval
         aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss)
         z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss)
diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py
index 67dd387a3950..132f17a9ba0f 100644
--- a/examples/language/openmoe/train.py
+++ b/examples/language/openmoe/train.py
@@ -15,10 +15,10 @@
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.context import MOE_CONTEXT
 from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.layer.moe import MoeCheckpintIO
-from colossalai.nn.layer.moe.utils import skip_init
+from colossalai.moe import MoeCheckpintIO
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device
 
 
@@ -95,7 +95,7 @@ def main():
     coordinator = DistCoordinator()
 
     # Set up moe
-    MOE_CONTEXT.setup(seed=42, parallel="EP")
+    MOE_MANAGER.setup(seed=42, parallel="EP")
 
     # Manage loggers
     disable_existing_loggers()
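`_calculate_router_loss` above averages the per-layer aux and z losses and scales each by its configured factor. A tiny numeric check with made-up values:

    # Stand-in values; in the model these lists are accumulated one entry per
    # MoE layer by the routers through MOE_MANAGER.add_loss.
    router_aux_loss = [0.02, 0.04, 0.03]
    router_z_loss = [0.10, 0.12, 0.08]
    router_aux_loss_factor, router_z_loss_factor = 0.01, 0.001

    aux_loss = router_aux_loss_factor * sum(router_aux_loss) / len(router_aux_loss)
    z_loss = router_z_loss_factor * sum(router_z_loss) / len(router_z_loss)
    print(aux_loss + z_loss)    # ~0.0004: layer-averaged losses, scaled by their factors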
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 7f9bfe376632..3371c35fd295 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -2,17 +2,14 @@
 import torch.distributed as dist
 import torch.nn as nn
 
-from colossalai.context import MOE_CONTEXT
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
 from colossalai.engine.gradient_handler.utils import bucket_allreduce
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import get_moe_epsize_param_dict
 from colossalai.nn import CheckpointModule
-from colossalai.nn.layer import SparseMLP
 from colossalai.registry import GRADIENT_HANDLER
 from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
-from colossalai.utils.moe import get_moe_epsize_param_dict
 
 
 class MoeModel(nn.Module):
@@ -39,7 +36,7 @@ def _forward(self, x):
         self.test_transform = TestSubModule()
 
     def forward(self, x):
-        MOE_CONTEXT.reset_loss()
+        MOE_MANAGER.reset_loss()
         x = self.test_embed(x)
         x = self.test_transform(x)
 
@@ -68,21 +65,19 @@ def handle_gradient(self):
         Then running an all-reduce operation for all parameters in experts
         across moe model parallel group
         """
-        global_data = gpc.data_parallel_size
-
-        if global_data > 1:
+        if dist.get_world_size() > 1:
             epsize_param_dict = get_moe_epsize_param_dict(self._model)
 
             # epsize is 1, indicating the params are replicated among processes in data parallelism
             # use the ParallelMode.DATA to get data parallel group
             # reduce gradients for all parameters in data parallelism
             if 1 in epsize_param_dict:
-                bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA))
+                bucket_allreduce(param_list=epsize_param_dict[1])
 
             for ep_size in epsize_param_dict:
-                if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
+                if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
                     bucket_allreduce(param_list=epsize_param_dict[ep_size],
-                                     group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
+                                     group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group)
 
 
 def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
@@ -160,3 +155,17 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
             assert torch.allclose(local_param.grad, all_grad)
         else:
             local_param.data.copy_(all_param.data)
+
+
+def assert_not_equal_in_group(tensor, process_group=None):
+    # all gather tensors from different ranks
+    world_size = dist.get_world_size(process_group)
+    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
+    dist.all_gather(tensor_list, tensor, group=process_group)
+
+    # check that adjacent pairs differ
+    for i in range(world_size - 1):
+        a = tensor_list[i]
+        b = tensor_list[i + 1]
+        assert not torch.allclose(
+            a, b), f'expected tensors on rank {i} and {i + 1} not to be equal but they are, {a} vs {b}'
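The new `assert_not_equal_in_group` compares only adjacent entries of the gathered list, which is enough to catch the failure mode it guards against in `test_grad_handler.py` below: all replicas accidentally initialized identically. A single-process illustration with a simulated gather:

    import torch

    # Stand-in for dist.all_gather: one tensor per "rank", all distinct.
    tensor_list = [torch.full((2,), float(rank)) for rank in range(4)]
    for i in range(len(tensor_list) - 1):
        assert not torch.allclose(tensor_list[i], tensor_list[i + 1])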
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index fd9f30ecb473..a588e6cd7148 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -4,21 +4,21 @@
 import torch.nn as nn
 
 import colossalai
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.nn.layer.moe import SparseMLP
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
-from colossalai.utils.moe import sync_moe_model_param
-from tests.test_moe.moe_utils import MoeGradientHandler
+from tests.test_moe.moe_utils import MoeGradientHandler, assert_not_equal_in_group
 
 BATCH_SIZE = 4
-DIM = 16
+DIM = 4
 
 
 def run_test(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_CONTEXT.setup(42)    # MOE initialization
+    MOE_MANAGER.setup(42)    # MOE initialization
     num_experts_list = [1, 2, 4]
     layer_list = []
     for num_experts in num_experts_list:
@@ -32,13 +32,22 @@ def run_test(rank, world_size, port):
     model = nn.ModuleList(layer_list)
     model = model.to(get_current_device())
 
+    dist_dict = MOE_MANAGER.parallel_info_dict
+    assert_not_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
+    assert_not_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
+    assert_not_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
+    assert_not_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
+    assert_not_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
+    assert_not_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)
+
     sync_moe_model_param(model)
 
-    dist_dict = MOE_CONTEXT.parallel_info_dict
     assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
     assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
     assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
     assert_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
+    assert_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
+    assert_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)
     # MoE model synchronization passed
 
     grad_handler = MoeGradientHandler(model, 0)
@@ -48,7 +57,7 @@ def run_test(rank, world_size, port):
     data = torch.randn(BATCH_SIZE, DIM, device=get_current_device())
     grad = torch.randn_like(data)
 
-    MOE_CONTEXT.reset_loss()
+    MOE_MANAGER.reset_loss()
     for layer in layer_list:
         data = layer(data)
     data.backward(grad)
@@ -58,6 +67,8 @@ def run_test(rank, world_size, port):
     assert_equal_in_group(layer_list[0].experts.wo.grad, dist_dict[1].dp_group)
     assert_equal_in_group(layer_list[1].experts.wi.grad, dist_dict[2].dp_group)
     assert_equal_in_group(layer_list[1].experts.wo.grad, dist_dict[2].dp_group)
+    assert_equal_in_group(layer_list[2].experts.wi.grad, dist_dict[4].dp_group)
+    assert_equal_in_group(layer_list[2].experts.wo.grad, dist_dict[4].dp_group)
     # MoE grad handler test passed
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index 46846206f7d1..0074a698fd96 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -1,15 +1,14 @@
 import pytest
 import torch
+import torch.distributed as dist
 
 import colossalai
-from colossalai.context import ParallelMode
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.core import global_context as gpc
-from colossalai.nn.layer.moe import SparseMLP
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
 
-BATCH_SIZE = 16
+BATCH_SIZE = 4
 NUM_EXPERTS = 4
 
 
@@ -22,10 +21,10 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     torch.backends.cuda.matmul.allow_tf32 = False
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
+    local_rank = dist.get_rank()
 
-    MOE_CONTEXT.setup(42)    # MOE environment initialization
-    MOE_CONTEXT.reset_loss()
+    MOE_MANAGER.setup(42)    # MOE environment initialization
+    MOE_MANAGER.reset_loss()
     torch.manual_seed(rs + local_rank)    # set each process has different random seed
 
     # get randomized data
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index fc51c2217233..20eb0969ca24 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -5,8 +5,8 @@
 import torch.distributed as dist
 
 import colossalai
-from colossalai.context import MOE_CONTEXT
-from colossalai.nn.layer.moe import MoeCheckpintIO
+from colossalai.moe import MoeCheckpintIO
+from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
 from tests.test_moe.moe_utils import MoeModel
@@ -32,7 +32,7 @@ def exam_moe_checkpoint():
 
 def _run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_CONTEXT.setup(seed=42)
+    MOE_MANAGER.setup(seed=42)
     exam_moe_checkpoint()
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index cb261912e0f6..253fe6a7c094 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -3,11 +3,11 @@
 import torch.distributed as dist
 
 import colossalai
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.nn.layer.moe import SparseMLP
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
-from colossalai.utils.moe import sync_moe_model_param
 from tests.test_moe.moe_utils import MoeGradientHandler, sync_tp_from_ep
 
 BATCH_SIZE = 4
@@ -16,7 +16,7 @@
 
 def run_test(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_CONTEXT.setup(42)    # MOE initialization
+    MOE_MANAGER.setup(42)    # MOE initialization
     ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM)
     tp_model = SparseMLP(num_experts=4, expert_parallel="TP", hidden_size=DIM, intermediate_size=DIM)
 
@@ -25,7 +25,7 @@ def run_test(rank, world_size, port):
 
     # sync ep param
     sync_moe_model_param(ep_model)
-    dist_dict = MOE_CONTEXT.parallel_info_dict
+    dist_dict = MOE_MANAGER.parallel_info_dict
     assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group)
     assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group)
     grad_handler = MoeGradientHandler(ep_model)
@@ -38,9 +38,9 @@ def run_test(rank, world_size, port):
     ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)]
 
     out_tp = tp_model(tp_data)
-    MOE_CONTEXT.reset_loss()
+    MOE_MANAGER.reset_loss()
     out_ep = ep_model(ep_data)
-    MOE_CONTEXT.reset_loss()
+    MOE_MANAGER.reset_loss()
     assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)])
 
     out_tp.mean().backward()
diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py
index fd87a9a3135d..f5d54ba290aa 100644
--- a/tests/test_moe/test_moe_group.py
+++ b/tests/test_moe/test_moe_group.py
@@ -3,11 +3,11 @@
 import torch.nn as nn
 
 import colossalai
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.nn.layer.moe import EPMLPExperts, TPMLPExperts
+from colossalai.moe import EPMLPExperts, TPMLPExperts
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
-from colossalai.utils.moe import sync_moe_model_param
 
 HIDDEN_SIZE = 4
 INTERMEDIATE_SIZE = 8
@@ -31,7 +31,7 @@ def run_moe_init(expert_cls):
     assert exp2.num_local_experts == 4
     assert exp3.num_local_experts == 8
 
-    parallel_info_dict = MOE_CONTEXT.parallel_info_dict
+    parallel_info_dict = MOE_MANAGER.parallel_info_dict
     rank = dist.get_rank()
 
     # group creation assert
@@ -59,7 +59,7 @@ def run_moe_init(expert_cls):
 
 def _run_test(rank, world_size, port, expert_cls):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_CONTEXT.setup(seed=42)
+    MOE_MANAGER.setup(seed=42)
     run_moe_init(expert_cls)
diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py
index e41a0d821a10..872b65c2d1f1 100644
--- a/tests/test_moe/test_moe_local.py
+++ b/tests/test_moe/test_moe_local.py
@@ -3,11 +3,11 @@
 import torch.distributed as dist
 
 import colossalai
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.nn.layer.moe import SparseMLP
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
 from colossalai.utils import get_current_device
-from colossalai.utils.moe import sync_moe_model_param
 from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep
 
 BATCH_SIZE = 4
@@ -16,7 +16,7 @@
 
 def run_test(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_CONTEXT.setup(42)    # MOE initialization
+    MOE_MANAGER.setup(42)    # MOE initialization
     ep_model = SparseMLP(num_experts=4, expert_parallel="EP", hidden_size=DIM, intermediate_size=DIM)
     local_model = SparseMLP(num_experts=4, expert_parallel=None, hidden_size=DIM, intermediate_size=DIM)
 
@@ -25,7 +25,7 @@ def run_test(rank, world_size, port):
 
     # sync ep param
     sync_moe_model_param(ep_model)
-    dist_dict = MOE_CONTEXT.parallel_info_dict
+    dist_dict = MOE_MANAGER.parallel_info_dict
     assert_equal_in_group(ep_model.experts.wi.data, dist_dict[2].dp_group)
     assert_equal_in_group(ep_model.experts.wo.data, dist_dict[2].dp_group)
     grad_handler = MoeGradientHandler(ep_model)
@@ -38,9 +38,9 @@ def run_test(rank, world_size, port):
     ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)]
 
     out_tp = local_model(tp_data)
-    MOE_CONTEXT.reset_loss()
+    MOE_MANAGER.reset_loss()
     out_ep = ep_model(ep_data)
-    MOE_CONTEXT.reset_loss()
+    MOE_MANAGER.reset_loss()
     assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)])
 
     out_tp.mean().backward()
@@ -54,10 +54,11 @@ def run_test(rank, world_size, port):
 
 
 @pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp():
-    spawn(run_test, 2)
+def test_moe_local(world_size):
+    spawn(run_test, world_size)
 
 
 if __name__ == '__main__':
-    test_moe_ep_tp()
+    test_moe_local()
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py
index f1f888203746..2b2afa4623b5 100644
--- a/tests/test_moe/test_moe_zero_fwd_bwd.py
+++ b/tests/test_moe/test_moe_zero_fwd_bwd.py
@@ -5,8 +5,7 @@
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.context import MOE_CONTEXT
-from colossalai.nn import MoeLoss
+from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
@@ -87,7 +86,7 @@ def run_zero_test(local_rank, world_size, stage=1):
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_CONTEXT.setup(seed=42)
+    MOE_MANAGER.setup(seed=42)
     seed_all(42 + rank)
     run_zero_test(rank, world_size, stage=1)
     run_zero_test(rank, world_size, stage=2)
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index 229ee528b4fc..38a5cfbfd66e 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -5,8 +5,7 @@
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.context import MOE_CONTEXT
-from colossalai.nn import MoeLoss
+from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
 
@@ -76,7 +75,7 @@ def run_zero_optim_test(local_rank, world_size, stage=1):
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    MOE_CONTEXT.setup(seed=42)
+    MOE_MANAGER.setup(seed=42)
     run_zero_optim_test(rank, world_size, stage=1)
     run_zero_optim_test(rank, world_size, stage=2)
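For reference, the launcher pattern all of these tests share, condensed into a hedged sketch that mirrors the code above (it needs a CUDA/NCCL environment to actually run; the test name and default world size are illustrative):

    import colossalai
    from colossalai.moe.manager import MOE_MANAGER
    from colossalai.testing import rerun_if_address_is_in_use, spawn

    def _run(rank, world_size, port):
        # each spawned process joins the NCCL group, then initializes MoE once
        colossalai.launch(config=dict(), rank=rank, world_size=world_size,
                          host='localhost', port=port, backend='nccl')
        MOE_MANAGER.setup(seed=42)

    @rerun_if_address_is_in_use()
    def test_boot(world_size=2):
        spawn(_run, world_size)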