diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 94deb6befeb5..8ba68270e514 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -21,19 +21,18 @@
     get_param_info,
     init_pipeline_optimizer,
 )
+from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.moe import MoECheckpointIO
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
 from colossalai.shardformer.policies.base_policy import Policy
-from colossalai.zero.low_level import LowLevelZeroOptimizer
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+from colossalai.zero.low_level import LowLevelOptStrategy, LowLevelZeroOptimizer, MoeZeroStrategy
 
-PP_AXIS, DP_AXIS, EP_AXIS, TP_AXIS = 0, 1, 2, 3
-
 
-class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
     def __init__(
         self,
         optimizer: Optimizer,
@@ -68,8 +67,39 @@ def __init__(
         self.pp_pg = pp_process_group
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
+
+        assert (
+            len(optimizer.param_groups) == 1
+        ), "Currently only one parameter group is supported, and we will support multiple groups later."
+        zero_params = list(filter(lambda x: not is_moe_tensor(x), model.parameters()))
+        moe_params = list(filter(lambda x: is_moe_tensor(x), model.parameters()))
+
+        optimizer.param_groups.clear()
+        optimizer.add_param_group({"params": zero_params})
+        optimizer.add_param_group({"params": moe_params})
+        strategies = [
+            LowLevelOptStrategy(
+                param_group=optimizer.param_groups[0],
+                process_group=dp_process_group,
+                reduce_bucket_size=reduce_bucket_size,
+                communication_dtype=communication_dtype,
+                overlap_communication=overlap_communication,
+                partition_grad=partition_grad,
+                cpu_offload=cpu_offload,
+            ),
+            MoeZeroStrategy(
+                param_group=optimizer.param_groups[1],
+                process_group=moe_extra_dp_process_group,
+                reduce_bucket_size=reduce_bucket_size,
+                communication_dtype=communication_dtype,
+                overlap_communication=overlap_communication,
+                partition_grad=partition_grad,
+                cpu_offload=cpu_offload,
+            ),
+        ]
         super().__init__(
             optimizer=optimizer,
+            group_strategies=strategies,
             initial_scale=initial_scale,
             min_scale=min_scale,
             growth_factor=growth_factor,
@@ -79,14 +109,7 @@ def __init__(
             max_scale=max_scale,
             clip_grad_norm=clip_grad_norm,
             verbose=verbose,
-            reduce_bucket_size=reduce_bucket_size,
-            communication_dtype=communication_dtype,
-            overlap_communication=overlap_communication,
-            partition_grad=partition_grad,
-            cpu_offload=cpu_offload,
-            dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
-            moe_extra_dp_process_group=moe_extra_dp_process_group,
         )
 
 
@@ -185,7 +208,6 @@ def __init__(
         custom_policy: Policy = None,
         checkpoint_io: Optional[MoECheckpointIO] = None,
     ) -> None:
-        global DP_AXIS, EP_AXIS
         world_size = dist.get_world_size()
         assert tp_size == 1, "Tensor parallel is not supported in MoE yet"
         assert (
@@ -224,28 +246,30 @@ def __init__(
         self.moe_dp_size = self.dp_size // self.ep_size
         self.use_ep_inside = use_ep_inside
         if self.use_ep_inside:
+            self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.pp_size, self.moe_dp_size, ep_size, tp_size)
-            self.moe_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
-            self.ep_group = self.pg_mesh.get_group_along_axis(EP_AXIS)
+            self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+            self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
             if dist.get_rank() == 0:
                 print(f"MoE Parallel: pp {self.pp_size}, outer_dp {self.moe_dp_size}, inner_ep {ep_size}, tp {tp_size}")
         else:
             warnings.warn("Using ep outside dp (cross-node) is strongly discouraged due to communication costs.")
+            self.pp_axis, self.dp_axis, self.ep_axis, self.tp_axis = 0, 2, 1, 3
             self.pg_mesh = ProcessGroupMesh(self.pp_size, ep_size, self.moe_dp_size, tp_size)
-            EP_AXIS = 1
-            DP_AXIS = 2
-            self.moe_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
-            self.ep_group = self.pg_mesh.get_group_along_axis(EP_AXIS)
+            self.moe_dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
+            self.ep_group = self.pg_mesh.get_group_along_axis(self.ep_axis)
             if dist.get_rank() == 0:
                 print(f"MoE Parallel: pp {self.pp_size}, outer_ep {ep_size}, inner_dp {self.moe_dp_size}, tp {tp_size}")
 
         if dist.get_rank() == 0:
             print(f"Non-MoE Parameter Parallel: pp {self.pp_size}, dp {self.dp_size}, tp {tp_size}")
-        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)  # TODO: support custom tp size for mixtral lm head
-        self.global_dp_group = self.pg_mesh.get_group_along_axis((DP_AXIS, EP_AXIS))
-        self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+        self.tp_group = self.pg_mesh.get_group_along_axis(
+            self.tp_axis
+        )  # TODO: support custom tp size for mixtral lm head
+        self.global_dp_group = self.pg_mesh.get_group_along_axis((self.dp_axis, self.ep_axis))
+        self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
         # TODO: Currently moe only support partially sequence parallel
-        self.sp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+        self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
 
         self.custom_policy = custom_policy
         self.stage_manager = None
@@ -257,7 +281,7 @@ def __init__(
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
             assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
-            self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
+            self.stage_manager = PipelineStageManager(self.pg_mesh, self.pp_axis)
             self.schedule = OneForwardOneBackwardSchedule(
                 self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
             )
@@ -329,7 +353,10 @@ def prepare_dataloader(
         """
         _kwargs = kwargs.copy()
         sampler = DistributedSampler(
-            dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
+            dataset,
+            num_replicas=self.pg_mesh.size(self.dp_axis),
+            rank=self.pg_mesh.coordinate(self.dp_axis),
+            shuffle=shuffle,
         )
 
         # Deterministic dataloader
@@ -409,7 +436,7 @@ def configure(
             else:
                 assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
                 assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
-                optimizer = HybridParallelZeroOptimizer(
+                optimizer = MoeHybridParallelZeroOptimizer(
                     optimizer,
                     model,
                     use_pipeline=self.enable_pipeline_parallelism,
diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
index 19b61730bded..ef37534fe01a 100644
--- a/colossalai/checkpoint_io/__init__.py
+++ b/colossalai/checkpoint_io/__init__.py
@@ -2,5 +2,12 @@
 from .general_checkpoint_io import GeneralCheckpointIO
 from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
 from .index_file import CheckpointIndexFile
+from .moe_checkpoint import MoECheckpointIO
 
-__all__ = ["CheckpointIO", "CheckpointIndexFile", "GeneralCheckpointIO", "HybridParallelCheckpointIO"]
+__all__ = [
+    "CheckpointIO",
+    "CheckpointIndexFile",
+    "GeneralCheckpointIO",
+    "HybridParallelCheckpointIO",
+    "MoECheckpointIO",
+]
diff --git a/colossalai/moe/checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py
similarity index 100%
rename from colossalai/moe/checkpoint.py
rename to colossalai/checkpoint_io/moe_checkpoint.py
diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py
index 2708764d89bd..0623d19efd5f 100644
--- a/colossalai/moe/__init__.py
+++ b/colossalai/moe/__init__.py
@@ -1,7 +1,5 @@
-from .checkpoint import MoECheckpointIO
 from .manager import MOE_MANAGER
 
 __all__ = [
-    "MoECheckpointIO",
     "MOE_MANAGER",
 ]
diff --git a/colossalai/zero/low_level/__init__.py b/colossalai/zero/low_level/__init__.py
index 270a6a6a4786..7e4702dfd38c 100644
--- a/colossalai/zero/low_level/__init__.py
+++ b/colossalai/zero/low_level/__init__.py
@@ -1,3 +1,4 @@
 from .low_level_optim import LowLevelZeroOptimizer
+from .low_level_strategy import LowLevelOptStrategy, LowLevelOptStrategyBase, MoeZeroStrategy
 
-__all__ = ["LowLevelZeroOptimizer"]
+__all__ = ["LowLevelZeroOptimizer", "LowLevelOptStrategy", "MoeZeroStrategy", "LowLevelOptStrategyBase"]
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 29903cb09219..bcbc7561dcd6 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -86,11 +86,11 @@ def __init__(
         elif len(self.optim.param_groups) > 1 and group_strategies is None:
             raise ValueError("group_strategies must be provided when the optimizer has multiple param groups")
 
-        self.masterparam2strategy: Dict[torch.nn.Parameter, LowLevelOptStrategyBase] = {}
+        self.workingparam2strategy: Dict[torch.nn.Parameter, LowLevelOptStrategyBase] = {}
         for grp, strategy in zip(self.optim.param_groups, group_strategies):
             assert grp["params"] is strategy.param_group["params"], "param groups should be in the same order"
             for param in strategy.working_param_group:
-                self.masterparam2strategy[param] = strategy
+                self.workingparam2strategy[param] = strategy
         self._group_strategies = group_strategies
 
         # initialize mixed precision mixin
@@ -139,9 +139,9 @@ def backward(self, loss, retain_graph=False):
     # another way of doing this is to reassign tensor.grad, however this won't apply for zero-2
     # since the shape doesn't match
 
-    def get_param_grad(self, master_param):
-        strategy = self.masterparam2strategy[master_param]
-        return strategy.get_param_grad(master_param)
+    def get_param_grad(self, working_param):
+        strategy = self.workingparam2strategy[working_param]
+        return strategy.get_param_grad(working_param)
 
     def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         # compute combined scale factor for this group
@@ -265,8 +265,9 @@ def update_master_params(self, model: nn.Module) -> None:
         Args:
             model (nn.Module): The model to update master params
         """
-        for master_param in model.parameters():
-            strategy = self.masterparam2strategy[master_param]
+        for working_param in model.parameters():
+            strategy = self.workingparam2strategy[working_param]
+            master_param = strategy.working2master(working_param=working_param)
             strategy.update_master_param(master_param)
 
     def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
diff --git a/colossalai/zero/low_level/low_level_strategy.py b/colossalai/zero/low_level/low_level_strategy.py
index d469e859d833..359e608d334b 100644
--- a/colossalai/zero/low_level/low_level_strategy.py
+++ b/colossalai/zero/low_level/low_level_strategy.py
@@ -304,6 +304,7 @@ def state_dict(self, optim: torch.optim.Optimizer) -> Dict:
     def update_master_param(self, master_param):
         working_param = self.master2working(master_param)
         padding_size = self.get_param_padding_size(working_param)
+        working_param = working_param.data.view(-1)
         if padding_size > 0:
             working_param = torch.nn.functional.pad(working_param, [0, padding_size])
         master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 3a3930fbc622..86f2d2909475 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -11,7 +11,7 @@
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
-from colossalai.moe import MoECheckpointIO
+from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.testing.utils import spawn