120 changes: 118 additions & 2 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -1,23 +1,71 @@
import random
from typing import Optional
from typing import Callable, Optional, OrderedDict, Tuple

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelAMPOptimizer,
HybridParallelModule,
HybridParallelNaiveOptimizer,
HybridParallelPlugin,
get_param_info,
init_pipeline_optimizer,
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoeCheckpintIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer

PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2


class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):

def __init__(
self,
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
backoff_factor: float = .5,
growth_interval: int = 2000,
hysteresis: int = 2,
max_scale: int = 2**24,
clip_grad_norm: float = 0.0, # grad clipping
verbose: bool = False,
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None,
extra_dp_process_group: Optional[ProcessGroup] = None):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
hysteresis, max_scale, clip_grad_norm, verbose, reduce_bucket_size, communication_dtype,
overlap_communication, partition_grad, cpu_offload, dp_process_group, tp_process_group,
forced_dtype, extra_dp_process_group)


class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
Plugin for Moe Hybrid Parallel Training.
@@ -78,6 +126,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
def __init__(self,
tp_size: int,
pp_size: int,
extra_dp_size: int = 1,
precision: str = 'fp16',
zero_stage: int = 0,
enable_all_optimization: bool = False,
@@ -106,6 +155,7 @@ def __init__(self,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
use_ep_inside: bool = True,
custom_policy: Policy = None) -> None:

super().__init__(tp_size=tp_size,
@@ -132,6 +182,23 @@ def __init__(self,
self.enable_sequence_parallelism = enable_sequence_parallelism
# we change pg mesh to (pp, dp, tp) for better moe performance
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)

# sync moe in outer dp group, and sync other param in global dp group
if extra_dp_size > 1:
ep_size = self.dp_size // extra_dp_size
if use_ep_inside:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size)
self.extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}")
else:
self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size)
self.extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2)
if dist.get_rank() == 0:
print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}")
else:
self.extra_dp_group = None

self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
@@ -235,3 +302,52 @@ def seed_worker(worker_id):
def get_checkpoint_io(self) -> MoeCheckpintIO:
self.checkpoint_io = MoeCheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io

def configure(
self,
model: Module,
optimizer: Optional[Optimizer] = None,
criterion: Optional[Callable] = None,
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
self.ddp_config, self.custom_policy)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ['fp16', 'bf16']:
optimizer = HybridParallelAMPOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config)
self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
optimizer.master_to_working_map)
else:
optimizer = HybridParallelNaiveOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
extra_dp_process_group=self.extra_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
**self.amp_config)
self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
optimizer._param_store.master_to_working_param)

return model, optimizer, criterion, dataloader, lr_scheduler
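
A minimal usage sketch of the new plugin arguments (not part of this diff; the tiny Linear model and Adam optimizer are placeholder stand-ins for a real MoE setup, and an 8-rank torchrun launch is assumed):

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch(config={})

# With 8 ranks, tp_size=1 and pp_size=1 leave dp_size=8.
# extra_dp_size=2 splits dp into an outer ZeRO group of 2 and an inner group of 8 // 2 = 4;
# use_ep_inside=True keeps that inner group on the innermost (fastest-varying) mesh axis.
plugin = MoeHybridParallelPlugin(tp_size=1,
                                 pp_size=1,
                                 extra_dp_size=2,
                                 use_ep_inside=True,
                                 precision='bf16',
                                 zero_stage=1)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(8, 8)    # placeholder for a real MoE model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)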
50 changes: 31 additions & 19 deletions colossalai/moe/experts.py
@@ -39,27 +39,28 @@ def __init__(
activation: Optional[Callable] = None,
drop_rate: float = 0,
gated: bool = False,
use_kernel: bool = False,
):
super().__init__()
assert expert_parallel in ["EP", "TP", None]
self.expert_parallel = expert_parallel
self.num_total_experts = num_experts
self.gated = gated
self.use_kernel = use_kernel
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size

# get expert parallel info
if expert_parallel is not None:
self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
num_experts, use_tp=True if expert_parallel == "TP" else False)
# get settings for different parallel
self.ep_size = get_ep_size(self)
if expert_parallel == "TP":
assert (
intermediate_size %
MOE_MANAGER.max_ep_size == 0), "intermediate_size should be divide by maximum expert parallel size"
intermediate_size = intermediate_size // MOE_MANAGER.max_ep_size
intermediate_size = intermediate_size // self.ep_size
num_experts = self.num_total_experts
else:
num_experts = self.num_local_experts
self.ep_size = get_ep_size(self)
else:
self.num_local_experts = self.num_total_experts
self.ep_size = 1
@@ -71,19 +72,6 @@ def __init__(
self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))

# expert param should be different
if expert_parallel is not None:
seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
else:
seed_ctx = nullcontext()
with seed_ctx:
if gated:
torch.nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
torch.nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
else:
torch.nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
torch.nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))

self.act_name = activation
self.act = get_activation(activation)
self.drop = nn.Dropout(p=drop_rate)
@@ -92,6 +80,24 @@ def __init__(
for param in self.parameters():
set_moe_tensor_info(param, self.moe_info)

# init param
self.reset_parameters()

@torch.no_grad()
def reset_parameters(self):
# expert param should be different
if self.expert_parallel is not None:
seed_ctx = Randomizer(MOE_MANAGER.seed).fork_rng(enable_cpu=True)
else:
seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
with seed_ctx:
if self.gated:
torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size))
torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size))
else:
torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))

def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor:
"""
Args:
@@ -110,7 +116,7 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -
x = x.reshape(e, -1, h)

if self.gated:
if HAS_TRITON and self.act_name == "swiglu":
if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
x = LlamaActCombine.apply(
torch.bmm(x, self.wi_gate[param_slice]),
torch.bmm(x, self.wi_up[param_slice]),
@@ -142,7 +148,9 @@ def __init__(
activation=None,
drop_rate: float = 0,
gated: bool = False,
use_kernel: bool = False,
):
# TODO: This class can be removed
super().__init__(
num_experts,
hidden_size,
@@ -151,6 +159,7 @@ def __init__(
activation,
drop_rate,
gated,
use_kernel,
)


@@ -168,7 +177,9 @@ def __init__(
activation: str = None,
drop_rate: float = 0,
gated: bool = False,
use_kernel: bool = False,
):
# TODO: This class can be removed
super().__init__(
num_experts,
hidden_size,
@@ -177,6 +188,7 @@ def __init__(
activation,
drop_rate,
gated,
use_kernel,
)


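The initialization moved above now lives in reset_parameters(), which forks the RNG per expert-parallel rank so each rank draws different but reproducible expert weights. A standalone sketch of that idea in plain PyTorch (a torch.Generator stands in for ColossalAI's Randomizer, and the per-rank seed offset is an assumption for illustration):

import math
import torch

def init_local_expert_weights(num_local_experts: int, hidden_size: int, intermediate_size: int,
                              base_seed: int, ep_rank: int) -> torch.Tensor:
    # fork a generator from the shared seed so every EP rank gets distinct,
    # reproducible weights for the experts it owns
    gen = torch.Generator().manual_seed(base_seed + ep_rank)
    wi = torch.empty(num_local_experts, hidden_size, intermediate_size)
    wi.normal_(mean=0.0, std=math.sqrt(0.1 / hidden_size), generator=gen)
    return wi

# two EP ranks drawing from the same base seed end up with different experts
w0 = init_local_expert_weights(2, 8, 16, base_seed=42, ep_rank=0)
w1 = init_local_expert_weights(2, 8, 16, base_seed=42, ep_rank=1)
assert not torch.equal(w0, w1)
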
44 changes: 29 additions & 15 deletions colossalai/moe/layers.py
@@ -51,22 +51,23 @@ def __init__(
min_capacity: int = 4,
noisy_policy: Optional[str] = None,
drop_tks: bool = True,
expert_parallel: str = "EP",
hidden_size: int = 2048,
intermediate_size: int = 2048,
activation: str = None,
gated: bool = False,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_experts = num_experts
self.use_kernel = MOE_MANAGER.use_kernel_optim
self.expert_parallel = expert_parallel
assert expert_parallel in [
self.expert_parallel = MOE_MANAGER.get_parallel()
self.gated = gated
assert self.expert_parallel in [
"EP",
"TP",
None,
], f"Unsupported expert parallel type {expert_parallel}"
], f"Unsupported expert parallel type {self.expert_parallel}"

# moe router
noisy_func = get_noise_generator(noisy_policy, num_experts)
@@ -80,23 +81,29 @@ def __init__(
)

# moe experts
expert_cls = get_expert_class(expert_parallel)
self.experts: BaseMLPExperts = expert_cls(
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
activation=activation,
gated=gated,
)
if expert_parallel is not None:
expert_cls = get_expert_class(self.expert_parallel)
self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
activation=activation,
gated=gated,
use_kernel=self.use_kernel)
if self.expert_parallel is not None:
self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts)
else:
self.ep_group = None
self.num_local_experts = self.experts.num_local_experts

# gate
self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))

# init param
self.reset_parameters()

@torch.no_grad()
def reset_parameters(self):
torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))

def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -171,7 +178,7 @@ def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
expert_output = AllToAll.apply(expert_output, self.ep_group)
return expert_output

def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
def _tp_process(self, dispatch_data: torch.Tensor, use_overlap: bool = False) -> torch.Tensor:
"""
TP with overlap.

@@ -191,6 +198,13 @@ def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: (num_experts, capacity, hidden_size)
"""
if use_overlap == False:
expert_in, _ = AllGather.apply(dispatch_data, self.ep_group)
expert_out = self.experts(expert_in)
expert_out, _ = ReduceScatter.apply(expert_out, self.ep_group)
return expert_out

# TODO: there is accuracy problem in overlap
chunk_num = 4
chunk_size = dispatch_data.shape[0] // chunk_num
out = torch.empty_like(dispatch_data)
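
The new non-overlap branch in _tp_process above is all-gather, local expert forward, then reduce-scatter. A rough sketch of that communication pattern with raw torch.distributed collectives (the group handling and the gather/scatter dimension are simplifying assumptions; the real code uses ColossalAI's AllGather/ReduceScatter autograd ops so gradients flow through the collectives):

import torch
import torch.distributed as dist

def tp_process_no_overlap(dispatch_data: torch.Tensor, expert_fn, group=None) -> torch.Tensor:
    # dispatch_data: this rank's (num_experts, capacity_shard, hidden_size) token shard
    world_size = dist.get_world_size(group)
    # 1) all-gather every rank's token shard so each rank sees the full capacity
    gathered = [torch.empty_like(dispatch_data) for _ in range(world_size)]
    dist.all_gather(gathered, dispatch_data, group=group)
    expert_in = torch.cat(gathered, dim=1)
    # 2) run the locally held slice of the expert weights on the full token set
    expert_out = expert_fn(expert_in)
    # 3) reduce-scatter sums the partial outputs and hands each rank back its token shard
    out = torch.empty_like(dispatch_data)
    shards = [c.contiguous() for c in expert_out.chunk(world_size, dim=1)]
    dist.reduce_scatter(out, shards, group=group)
    return out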