79 changes: 53 additions & 26 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -21,19 +21,18 @@
get_param_info,
init_pipeline_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MoECheckpointIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelOptStrategy, LowLevelZeroOptimizer, MoeZeroStrategy

PP_AXIS, DP_AXIS, EP_AXIS, TP_AXIS = 0, 1, 2, 3


class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
@@ -68,8 +67,39 @@ def __init__(
self.pp_pg = pp_process_group
if use_pipeline:
init_pipeline_optimizer(optimizer, model)

assert (
len(optimizer.param_groups) == 1
), "Currently only one parameter group is supported, and we will support multiple groups later."
zero_params = list(filter(lambda x: not is_moe_tensor(x), model.parameters()))
moe_params = list(filter(lambda x: is_moe_tensor(x), model.parameters()))

optimizer.param_groups.clear()
optimizer.add_param_group({"params": zero_params})
optimizer.add_param_group({"params": moe_params})
strategies = [
LowLevelOptStrategy(
param_group=optimizer.param_groups[0],
process_group=dp_process_group,
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
),
MoeZeroStrategy(
param_group=optimizer.param_groups[1],
process_group=moe_extra_dp_process_group,
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
),
]
super().__init__(
optimizer=optimizer,
group_strategies=strategies,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
@@ -79,14 +109,7 @@ def __init__(
max_scale=max_scale,
clip_grad_norm=clip_grad_norm,
verbose=verbose,
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
moe_extra_dp_process_group=moe_extra_dp_process_group,
)


@@ -185,7 +208,6 @@ def __init__(
custom_policy: Policy = None,
checkpoint_io: Optional[MoECheckpointIO] = None,
) -> None:
global DP_AXIS, EP_AXIS
world_size = dist.get_world_size()
assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
assert (
@@ -224,28 +246,30 @@ def __init__(
self.moe_dp_size = self.dp_size // self.ep_size
self.use_ep_inside = use_ep_inside
if self.use_ep_inside:
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
self.moe_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.ep_group = self.pg_mesh.get_group_along_axis(EP_AXIS)
self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
if dist.get_rank() == 0:
print(f"MoE Parallel: pp {self.pp_size}, outer_dp {self.moe_dp_size}, inner_ep {ep_size}, tp {tp_size}")
else:
warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)
EP_AXIS = 1
DP_AXIS = 2
self.moe_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.ep_group = self.pg_mesh.get_group_along_axis(EP_AXIS)
self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
if dist.get_rank() == 0:
print(f"MoE Parallel: pp {self.pp_size}, outer_ep {ep_size}, inner_dp {self.moe_dp_size}, tp {tp_size}")
if dist.get_rank() == 0:
print(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}")

self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) # TODO: support custom tp size for mixtral lm head
self.global_dp_group = self.pg_mesh.get_group_along_axis((DP_AXIS, EP_AXIS))
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.tp_group = self.pg_mesh.get_group_along_axis(
self.tp_axis
) # TODO: support custom tp size for mixtral lm head
self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
# TODO: Currently moe only support partially sequence parallel
self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)

self.custom_policy = custom_policy
self.stage_manager = None
@@ -257,7 +281,7 @@ def __init__(
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
)
@@ -329,7 +353,10 @@ def prepare_dataloader(
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(
dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
dataset,
num_replicas=self.pg_mesh.size(self.dp_axis),
rank=self.pg_mesh.coordinate(self.dp_axis),
shuffle=shuffle,
)

# Deterministic dataloader
@@ -409,7 +436,7 @@ def configure(
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
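Note: a minimal standalone sketch of the parameter split performed by the new optimizer above — parameters are partitioned by is_moe_tensor into a dense (ZeRO) group and an expert group before the per-group strategies are attached. The toy helper name split_param_groups and the use of a plain torch optimizer are illustrative; only the filter / clear / add_param_group pattern comes from this diff.

import torch
from colossalai.tensor.moe_tensor.api import is_moe_tensor

def split_param_groups(model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> None:
    # Group 0: dense parameters handled by the plain ZeRO strategy.
    zero_params = [p for p in model.parameters() if not is_moe_tensor(p)]
    # Group 1: expert parameters handled by the MoE ZeRO strategy.
    moe_params = [p for p in model.parameters() if is_moe_tensor(p)]
    # Rebuild the groups in a fixed order so strategies[i] matches param_groups[i].
    optimizer.param_groups.clear()
    optimizer.add_param_group({"params": zero_params})
    optimizer.add_param_group({"params": moe_params})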
9 changes: 8 additions & 1 deletion colossalai/checkpoint_io/__init__.py
@@ -2,5 +2,12 @@
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from .index_file import CheckpointIndexFile
from .moe_checkpoint import MoECheckpointIO

__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
__all__ = [
"CheckpointIO",
"CheckpointIndexFile",
"GeneralCheckpointIO",
"HybridParallelCheckpointIO",
"MoECheckpointIO",
]
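For downstream callers the re-export changes only the import path:

# Old location (removed from colossalai/moe/__init__.py below):
#   from colossalai.moe import MoECheckpointIO
# New location:
from colossalai.checkpoint_io import MoECheckpointIO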
2 changes: 0 additions & 2 deletions colossalai/moe/__init__.py
@@ -1,7 +1,5 @@
from .checkpoint import MoECheckpointIO
from .manager import MOE_MANAGER

__all__ = [
"MoECheckpointIO",
"MOE_MANAGER",
]
3 changes: 2 additions & 1 deletion colossalai/zero/low_level/__init__.py
@@ -1,3 +1,4 @@
from .low_level_optim import LowLevelZeroOptimizer
from .low_level_strategy import LowLevelOptStrategy, LowLevelOptStrategyBase, MoeZeroStrategy

__all__ = ["LowLevelZeroOptimizer"]
__all__ = ["LowLevelZeroOptimizer", "LowLevelOptStrategy", "MoeZeroStrategy", "LowLevelOptStrategyBase"]
15 changes: 8 additions & 7 deletions colossalai/zero/low_level/low_level_optim.py
@@ -86,11 +86,11 @@ def __init__(
elif len(self.optim.param_groups) > 1 and group_strategies is None:
raise ValueError("group_strategies must be provided when the optimizer has multiple param groups")

self.masterparam2strategy: Dict[torch.nn.Parameter, LowLevelOptStrategyBase] = {}
self.workingparam2strategy: Dict[torch.nn.Parameter, LowLevelOptStrategyBase] = {}
for grp, strategy in zip(self.optim.param_groups, group_strategies):
assert grp["params"] is strategy.param_group["params"], "param groups should be in the same order"
for param in strategy.working_param_group:
self.masterparam2strategy[param] = strategy
self.workingparam2strategy[param] = strategy
self._group_strategies = group_strategies

# initialize mixed precision mixin
@@ -139,9 +139,9 @@ def backward(self, loss, retain_graph=False):

# another way of doing this is to reassign tensor.grad, however this won't apply for zero-2
# since the shape doesn't match
def get_param_grad(self, master_param):
strategy = self.masterparam2strategy[master_param]
return strategy.get_param_grad(master_param)
def get_param_grad(self, working_param):
strategy = self.workingparam2strategy[working_param]
return strategy.get_param_grad(working_param)

def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# compute combined scale factor for this group
@@ -265,8 +265,9 @@ def update_master_params(self, model: nn.Module) -> None:
Args:
model (nn.Module): The model to update master params
"""
for master_param in model.parameters():
strategy = self.masterparam2strategy[master_param]
for working_param in model.parameters():
strategy = self.workingparam2strategy[working_param]
master_param = strategy.working2master(working_param=working_param)
strategy.update_master_param(master_param)

def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
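Since the lookup table is now keyed by working parameters, gradient access goes through the model's own parameters; a short usage sketch (zero_optim and model taken from the sketches above):

# Each working parameter is routed to the strategy (dense or MoE) that owns it.
for working_param in model.parameters():
    grad = zero_optim.get_param_grad(working_param)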
1 change: 1 addition & 0 deletions colossalai/zero/low_level/low_level_strategy.py
@@ -304,6 +304,7 @@ def state_dict(self, optim: torch.optim.Optimizer) -> Dict:
def update_master_param(self, master_param):
working_param = self.master2working(master_param)
padding_size = self.get_param_padding_size(working_param)
working_param = working_param.data.view(-1)
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
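Why the added view(-1) matters: F.pad with [0, n] pads only the last dimension, so the working parameter is flattened before it is padded up to a multiple of the world size and chunked. A standalone illustration with made-up sizes:

import torch
import torch.nn.functional as F

world_size, local_rank = 4, 1
working_param = torch.randn(3, 5)                     # 15 elements
padding_size = (-working_param.numel()) % world_size  # pad to 16

flat = working_param.data.view(-1)                    # flatten first, as in the fix above
if padding_size > 0:
    flat = F.pad(flat, [0, padding_size])
master_shard = flat.chunk(world_size)[local_rank]     # each rank keeps its 4-element shard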
2 changes: 1 addition & 1 deletion tests/test_moe/test_moe_checkpoint.py
@@ -11,7 +11,7 @@
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe import MoECheckpointIO
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.testing.utils import spawn

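A hedged end-to-end sketch of how the updated test exercises the new import path; the model/optimizer construction and the exact plugin keyword set are assumptions, only the class names and import locations come from this diff.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.checkpoint_io import MoECheckpointIO

colossalai.launch_from_torch(config={})
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    checkpoint_io=MoECheckpointIO,  # assumed: whether the plugin expects the class or an instance is not shown here
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)  # model/optimizer built elsewhere
booster.save_model(model, "./moe_ckpt", shard=True)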