48 changes: 13 additions & 35 deletions colossalai/context/moe_context.py
@@ -3,29 +3,9 @@
 import torch
 import torch.distributed as dist
 
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor import ProcessGroup
-
-
-def _check_sanity():
-    from colossalai.core import global_context as gpc
-    if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
-        raise NotImplementedError("Moe is not compatible with tensor or "
-                                  "pipeline parallel at present.")
-
-
-class MoeParallelInfo:
-    """Moe parallelism information, storing parallel sizes and groups.
-    """
-
-    def __init__(self, ep_size: int, dp_size: int):
-        _check_sanity()
-        self.ep_size = ep_size
-        self.dp_size = dp_size
-        self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
-        self.ep_group = self.pg.tp_process_group()
-        self.dp_group = self.pg.dp_process_group()
+from colossalai.tensor.moe_tensor.api import get_moe_info
+from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
 
 
 class MoeContext(metaclass=SingletonMeta):
@@ -34,12 +14,12 @@ class MoeContext(metaclass=SingletonMeta):
     """
 
     def __init__(self):
-        self.world_size = 1
+        self.world_size = None
         # Users may want to set maximum expert parallel size smaller than the world size
         # since very low bandwidth across nodes may constrain the performance of MoE
         # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
-        self.max_ep_size = 1
-        self.min_dp_size = 1
+        self.max_ep_size = None
+        self.min_dp_size = None
         self.aux_loss = None
         self.use_kernel_optim = True
 
@@ -54,17 +34,12 @@ def parallel_info_dict(self):
     def is_initialized(self):
         return self.has_setup
 
-    def setup(self, seed: int, use_kernel_optim: bool = True):
+    def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8):
         assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
-        _check_sanity()
         assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
 
         self.world_size = dist.get_world_size()
-
-        from colossalai.core import global_context as gpc
-        self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
-        assert self.world_size % self.max_ep_size == 0, \
-            "Maximum expert parallel size must be a factor of the number of GPUs"
+        self.max_ep_size = min(max_ep_size, dist.get_world_size())
         self.min_dp_size = self.world_size // self.max_ep_size
 
         # Enabling kernel optimization may raise error in some cases
@@ -75,7 +50,7 @@ def setup(self, seed: int, use_kernel_optim: bool = True):
         moe_set_seed(seed)
         self.has_setup = True
 
-    def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
+    def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
        """Calculate the Data Parallel Group and Expert Parallel Group.
 
         Parameters
@@ -104,12 +79,15 @@ def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
         ep_size = self.max_ep_size // dp_size
 
         # Calculate the number of experts for each GPU
-        num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
+        if use_tp:
+            num_local_experts = num_experts
+        else:
+            num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
 
         # Don't forget to multiply minimum data parallel size
         dp_size *= self.min_dp_size
         if not (ep_size in self.parallel_info_dict):
-            self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)
+            self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size)
 
         return num_local_experts, self.parallel_info_dict[ep_size]
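
For orientation, the size arithmetic in get_info() after this change can be restated as a standalone sketch. The lt_flag handling and the initial dp_size are reconstructed from context outside this hunk, so treat them as assumptions:

def get_info_sketch(num_experts: int, max_ep_size: int, min_dp_size: int, use_tp: bool = False):
    # Assumption: lt_flag marks the case with fewer experts than max_ep_size
    lt_flag = num_experts < max_ep_size
    dp_size = max_ep_size // num_experts if lt_flag else 1
    ep_size = max_ep_size // dp_size    # shrink EP so every rank holds at least one expert
    if use_tp:
        num_local_experts = num_experts    # with tensor parallelism, each rank keeps all experts
    else:
        num_local_experts = 1 if lt_flag else num_experts // max_ep_size
    dp_size *= min_dp_size    # fold in the minimum data parallel size
    return num_local_experts, ep_size, dp_size

For example, get_info_sketch(16, 8, 1) returns (2, 8, 1): with 8 GPUs and 16 experts, each GPU holds 2 experts inside one expert parallel group of size 8.
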
10 changes: 5 additions & 5 deletions colossalai/nn/layer/moe/__init__.py
@@ -1,10 +1,10 @@
-from .checkpoint import load_moe_model, save_moe_model
-from .experts import Experts, FFNExperts, TPExperts
-from .layers import MoeLayer, MoeModule
+from .checkpoint import MoeCheckpintIO
+from .experts import EPMLPExperts, TPMLPExperts
+from .layers import MoeLayer, MoeModule, SparseMLP
 from .routers import MoeRouter, Top1Router, Top2Router
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
 
 __all__ = [
-    'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter', 'save_moe_model', 'load_moe_model'
+    'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'MoeModule', 'NormalNoiseGenerator',
+    'UniformNoiseGenerator', 'build_ffn_experts', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO'
 ]
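
The public surface now routes checkpointing through MoeCheckpintIO. A hedged usage sketch (SparseMLP's constructor arguments are not part of this diff and are elided; save_model/load_model are assumed to be the generic CheckpointIO entry points that dispatch to the hooks implemented in checkpoint.py):

import torch.nn as nn
from colossalai.nn.layer.moe import MoeCheckpintIO

def checkpoint_roundtrip(model: nn.Module, path: str) -> None:
    # MoeCheckpintIO overrides the unsharded model hooks; the sharded and
    # optimizer hooks still raise NotImplementedError in this PR
    ckpt_io = MoeCheckpintIO()
    ckpt_io.save_model(model, path)
    ckpt_io.load_model(model, path)
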
77 changes: 49 additions & 28 deletions colossalai/nn/layer/moe/checkpoint.py
@@ -1,40 +1,61 @@
+from pathlib import Path
+from typing import Optional
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.optim import Optimizer
 
+from colossalai.checkpoint_io import CheckpointIO
+from colossalai.tensor.moe_tensor.api import get_ep_group
 
 
+class MoeCheckpintIO(CheckpointIO):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+        state_dict = torch.load(checkpoint)
+        for name, param in model.named_parameters():
+            if '.experts.' in name:
+                ep_rank = dist.get_rank(get_ep_group(param))
+                ep_size = dist.get_world_size(get_ep_group(param))
+                for rank in range(ep_size):
+                    new_name = name.replace('.experts.', f'.experts.{rank}.')
+                    if rank == ep_rank:
+                        state_dict[name] = state_dict.pop(new_name)
+                    else:
+                        state_dict.pop(new_name)
 
-from .experts import MoeExperts
+        model.load_state_dict(state_dict, strict=strict)
 
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        state_dict = model.state_dict()
+        if dist.get_rank() == 0:
+            torch.save(state_dict, checkpoint)
+        dist.barrier()
 
-def save_moe_model(model: nn.Module, save_path: str):
-    state_dict = model.state_dict()
-    if dist.get_rank() == 0:
-        torch.save(state_dict, save_path)
-    dist.barrier()
+    def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool):
+        raise NotImplementedError()
+
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: Optional[str],
+                           size_per_shard: int, use_safetensors: bool):
+        raise NotImplementedError()
 
-def load_moe_model(model: nn.Module, load_path: str):
-    state_dict = torch.load(load_path)
+    # ========================================================
+    # Abstract methods for optimizer loading/saving implementation
+    # ========================================================
 
-    for prefix, module in model.named_modules():
-        if prefix.endswith('.moe_layer.experts'):
-            # this module should be an Experts instance
-            assert isinstance(module, MoeExperts)
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str):
+        raise NotImplementedError()
 
-            ep_rank = dist.get_rank(module.dist_info.ep_group)
-            num_local = module.num_local_experts
-            for i in range(num_local):
-                expert_id = ep_rank * num_local + i
-                for name, _ in module.experts[i].named_parameters():
-                    cur_key = f'{prefix}.experts.{i}.{name}'
-                    param_key = f'{prefix}.experts.{expert_id}.{name}'
-                    load_param = state_dict[param_key]
-                    state_dict[cur_key] = load_param
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        raise NotImplementedError()
 
-            for name, _ in module.experts[0].named_parameters():
-                pop_pre = f'{prefix}.experts.'
-                pop_suf = f'.{name}'
-                for i in range(num_local, module.num_total_experts):
-                    pop_key = f'{pop_pre}{i}{pop_suf}'
-                    state_dict.pop(pop_key)
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
+        raise NotImplementedError()
 
-    model.load_state_dict(state_dict)
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
+        raise NotImplementedError()
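
The core of load_unsharded_model() above is a key remapping between the checkpoint layout, where expert parameters carry an expert-parallel rank ('...experts.<rank>...'), and the local model layout, which does not. A toy, framework-free restatement of that loop (all names here are illustrative):

def remap_expert_keys(state_dict: dict, param_names: list, ep_rank: int, ep_size: int) -> dict:
    # Each EP rank claims its own slice of the expert weights and drops the rest
    for name in param_names:
        if '.experts.' in name:
            for rank in range(ep_size):
                ranked = name.replace('.experts.', f'.experts.{rank}.')
                if rank == ep_rank:
                    state_dict[name] = state_dict.pop(ranked)    # keep our shard under the local key
                else:
                    state_dict.pop(ranked)    # discard shards owned by other ranks
    return state_dict

With ep_size=2, rank 1 maps {'mlp.experts.0.w': A, 'mlp.experts.1.w': B} to {'mlp.experts.w': B}, which is exactly the form that model.load_state_dict() then consumes.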