Skip to content
13 changes: 9 additions & 4 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer

from .pp_plugin_base import PipelinePluginBase
Expand All @@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
class HybridParallelModule(ModelWrapper):

def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
ddp_config: dict) -> None:
ddp_config: dict, custom_policy: Policy) -> None:

self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group

shardformer = ShardFormer(shard_config)
module, self.shared_params = shardformer.optimize(module)
if custom_policy is not None:
assert isinstance(custom_policy, object)
module, self.shared_params = shardformer.optimize(module, policy=custom_policy)

# setting process groups for shared parameters
self.shared_param_process_groups = []
Expand Down Expand Up @@ -302,7 +305,8 @@ def __init__(self,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True) -> None:
overlap_communication: bool = True,
custom_policy: Policy = None) -> None:

super().__init__()
assert dist.get_world_size() % (
Expand All @@ -326,6 +330,7 @@ def __init__(self,
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
Expand Down Expand Up @@ -405,7 +410,7 @@ def configure(
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
self.ddp_config)
self.ddp_config, self.custom_policy)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ['fp16', 'bf16']:
Expand Down
173 changes: 173 additions & 0 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from typing import Optional

import torch
import torch.distributed as dist

from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelPlugin
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy

# Axis indices used to look up groups in the plugin's process group mesh,
# which is built in (pp, dp, tp) order.
PP_AXIS = 0
DP_AXIS = 1
TP_AXIS = 2


class MoeHybridParallelPlugin(HybridParallelPlugin):
    """
    Plugin for MoE Hybrid Parallel Training.
    Tensor parallel, pipeline parallel and data parallel (DDP/ZeRO) can be picked and combined in this plugin.
    The size of tp and pp should be passed in by user, then the size of dp is automatically calculated from dp_size = world_size / (tp_size * pp_size).
    Unlike the parent plugin, the process group mesh is laid out as (pp, dp, tp).

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2)

        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

    Args:
        tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
        pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
        precision (str, optional): Specifies the precision of parameters during training.
                                    Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16', otherwise model is trained with 'fp32'.
                                    Defaults to 'fp16'.
        zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
                                        When set to 0, ZeRO will not be used. Defaults to 0.
                                        Must be 0 or 1 when pp_size > 1.
        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
                                                    Currently all the optimization methods include fused normalization, flash attention and JIT.
                                                    Defaults to False.
        enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Requires tp_size > 1. Defaults to False.
        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
        num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
        microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
            Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
            If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
        initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
        min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
        growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
        backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
        growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
        hysteresis (int, optional):  The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
        max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
        max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
        broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
        ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
        find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
        check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
        gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
        static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
        zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
        cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
        communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
        overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
        custom_policy (Policy, optional): Stored as ``self.custom_policy``; the inherited ``configure`` forwards it to the
            model wrapper, which passes it to Shardformer as the sharding policy. Defaults to None.
    """

    def __init__(self,
                 tp_size: int,
                 pp_size: int,
                 precision: str = 'fp16',
                 zero_stage: int = 0,
                 enable_all_optimization: bool = False,
                 enable_fused_normalization: bool = False,
                 enable_flash_attention: bool = False,
                 enable_jit_fused: bool = False,
                 enable_sequence_parallelism: bool = False,
                 enable_sequence_overlap: bool = False,
                 num_microbatches: Optional[int] = None,
                 microbatch_size: Optional[int] = None,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32,
                 max_norm: float = 0,
                 broadcast_buffers: bool = True,
                 ddp_bucket_cap_mb: int = 25,
                 find_unused_parameters: bool = False,
                 check_reduction: bool = False,
                 gradient_as_bucket_view: bool = False,
                 static_graph: bool = False,
                 zero_bucket_size_in_m: int = 12,
                 cpu_offload: bool = False,
                 communication_dtype: Optional[torch.dtype] = None,
                 overlap_communication: bool = True,
                 custom_policy: Optional[Policy] = None) -> None:

        # NOTE(review): ``super()`` resolves to ``HybridParallelPlugin`` here. If its
        # ``__init__`` requires ``tp_size``/``pp_size`` this bare call raises ``TypeError``
        # -- confirm against the parent signature before relying on this class.
        super().__init__()

        # Query the world size once; it is needed for validation and for deriving dp_size.
        world_size = dist.get_world_size()
        assert world_size % (
            tp_size * pp_size
        ) == 0, f'world size {world_size} is not divisible by tp_size {tp_size} * pp_size {pp_size}'

        if enable_sequence_parallelism:
            assert tp_size > 1, 'Sequence parallelism can only be enabled when using tensor parallelism (tp_size > 1)'

        self.tp_size = tp_size
        self.pp_size = pp_size
        self.dp_size = world_size // (tp_size * pp_size)
        self.precision = precision
        self.zero_stage = zero_stage
        self.cpu_offload = cpu_offload
        self.enable_all_optimization = enable_all_optimization
        self.enable_fused_normalization = enable_fused_normalization
        self.enable_flash_attention = enable_flash_attention
        self.enable_jit_fused = enable_jit_fused
        self.enable_sequence_parallelism = enable_sequence_parallelism
        # we change pg mesh to (pp, dp, tp) for better moe performance
        self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
        self.stage_manager = None
        self.schedule = None
        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)

        # The stage manager and schedule are only built for pipeline parallelism;
        # otherwise both stay None.
        if self.pp_size > 1:
            assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
            assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
            self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
                                                          num_microbatches=num_microbatches,
                                                          microbatch_size=microbatch_size)

        # One process group per mesh axis.
        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)

        self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
                                        pipeline_stage_manager=self.stage_manager,
                                        enable_tensor_parallelism=self.tp_size > 1,
                                        enable_all_optimization=self.enable_all_optimization,
                                        enable_fused_normalization=self.enable_fused_normalization,
                                        enable_flash_attention=self.enable_flash_attention,
                                        enable_jit_fused=self.enable_jit_fused,
                                        enable_sequence_parallelism=enable_sequence_parallelism,
                                        enable_sequence_overlap=enable_sequence_overlap)

        # Loss-scaling settings for mixed precision training.
        self.amp_config = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
        )

        # Keyword arguments collected for DDP.
        self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
                               bucket_cap_mb=ddp_bucket_cap_mb,
                               find_unused_parameters=find_unused_parameters,
                               check_reduction=check_reduction,
                               gradient_as_bucket_view=gradient_as_bucket_view,
                               static_graph=static_graph)

        # Keyword arguments collected for ZeRO; gradients are partitioned only at stage 2.
        self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
                                communication_dtype=communication_dtype,
                                overlap_communication=overlap_communication,
                                cpu_offload=cpu_offload,
                                partition_grad=(self.zero_stage == 2))

        self.max_norm = max_norm
91 changes: 70 additions & 21 deletions colossalai/moe/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def __init__(self):
self.router_z_loss = []
self.parallel = None
self.seed = None
self.use_kernel_optim = True
self.mode = None
self.use_kernel_optim = False
self.use_ep_inside = None

self.has_setup = False
self._parallel_info_dict = dict()
Expand All @@ -37,15 +39,53 @@ def parallel_info_dict(self):
def is_initialized(self):
return self.has_setup

def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: bool = None):
def setup(self,
seed: int,
use_kernel_optim: bool = True,
parallel: bool = None,
mode: str = "dynamic",
max_ep_size: int = 8,
fixed_dp_size: int = 0,
fixed_ep_size: int = 0,
fixed_pp_size: int = 0,
use_ep_inside: bool = True) -> None:
"""
Setup MoE distributed context.

Args:
seed (int): Base random seed. The rank of the current process is added to it, so each process ends up with a different seed.
use_kernel_optim (bool, optional): Use cuda kernel. Defaults to True.
parallel (bool, optional): Parallel mode, should be EP, TP or None. Defaults to None.
mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
In fixed mode, the ep, dp and pp sizes are fixed to the given values.
In dynamic mode, the ep size and dp size will be changed according to num experts.
max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8.
fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
"""
assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
assert torch.cuda.is_available(), "MoE requires to enable CUDA first"

self.world_size = dist.get_world_size()
self.seed = seed + dist.get_rank()
self.max_ep_size = min(max_ep_size, dist.get_world_size())
self.min_dp_size = self.world_size // self.max_ep_size
self.parallel = parallel
self.use_ep_inside = use_ep_inside

# init by mode
self.mode = mode
assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
if self.mode == "dynamic":
self.max_ep_size = min(max_ep_size, dist.get_world_size())
self.min_dp_size = self.world_size // self.max_ep_size
else:
assert fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0, "dp_size, ep_size and pp_size should be greater than 0"
assert isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(
fixed_pp_size, int), "dp_size, ep_size and pp_size should be int"
self.ep_size = fixed_ep_size
self.dp_size = fixed_dp_size
self.pp_size = fixed_pp_size

# Enabling kernel optimization may raise error in some cases
# Users can close kernel optimization manually
Expand All @@ -67,30 +107,39 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara
number of local experts, the MoeParallelInfo of the current ep_size
"""

gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less

assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
" is not a multiple of ep size or vice versa."

# If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
# there are multiple experts in each GPU and each GPU has different experts
# So it's data parallel size is 1
# Otherwise, there is only one expert in each GPU
# The data parallel size should be calculated
dp_size = 1 if gt_flag else self.max_ep_size // num_experts
ep_size = self.max_ep_size // dp_size
if self.mode == "dynamic":
gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less

assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
" is not a multiple of ep size or vice versa."

# If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
# there are multiple experts in each GPU and each GPU has different experts
# So it's data parallel size is 1
# Otherwise, there is only one expert in each GPU
# The data parallel size should be calculated
dp_size = 1 if gt_flag else self.max_ep_size // num_experts
ep_size = self.max_ep_size // dp_size
# Don't forget to multiply minimum data parallel size
dp_size *= self.min_dp_size
pp_size = 1
else:
dp_size = self.dp_size
ep_size = self.ep_size
pp_size = self.pp_size

# Calculate the number of experts for each GPU
if use_tp:
num_local_experts = num_experts
else:
num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
if self.mode == "dynamic":
num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
else:
num_local_experts = num_experts // ep_size

# Don't forget to multiply minimum data parallel size
dp_size *= self.min_dp_size
if not (ep_size in self.parallel_info_dict):
self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size)
self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)

return num_local_experts, self.parallel_info_dict[ep_size]

Expand Down
Loading