From 7f42527d30d5f4b79807fbcf8b96674282a2c13a Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 18 Oct 2023 11:25:26 +0800 Subject: [PATCH 01/17] doc --- colossalai/moe/layers.py | 47 ++++++++++++---------------------- colossalai/moe/load_balance.py | 5 ++++ 2 files changed, 21 insertions(+), 31 deletions(-) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 9846cd432b53..8965ab1b2398 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -147,6 +147,7 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: with torch.no_grad(): # TODO: optimize computation expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1] + # TODO: bincount introduces synchronize, fix it expert_load = torch.bincount(expert_load.view(-1)) self.load_balancer.update_load(expert_load) @@ -189,10 +190,7 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_out = self.experts(expert_in) return expert_out - def _ep_process(self, - dispatch_data: torch.Tensor, - overlap: bool = True - ) -> torch.Tensor: + def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = True) -> torch.Tensor: """ Expert Parallel @@ -210,6 +208,7 @@ def _ep_process(self, return expert_output else: + @dataclasses.dataclass class Capsule(): data: torch.Tensor @@ -238,24 +237,17 @@ class Capsule(): # all2all last output if _expert_out is not None: - expert_out = Capsule( - *AllToAll.apply(_expert_out.data, self.ep_group, True), - ) + expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),) _expert_out = None # all2all next input if 0 <= i < NUM_CHUNK: - _expert_in = Capsule( - *AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True) - ) + _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True)) # compute if expert_in is not None: expert_in.handle.wait() - _expert_out = Capsule( - data=self.experts(expert_in.data), - handle=None - ) + _expert_out = 
Capsule(data=self.experts(expert_in.data), handle=None) expert_in = None if _expert_in is not None: @@ -264,10 +256,7 @@ class Capsule(): return output - def _tp_process(self, - dispatch_data: torch.Tensor, - overlap: bool = True - ) -> torch.Tensor: + def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = True) -> torch.Tensor: """ without overlap: | C | @@ -291,6 +280,7 @@ def _tp_process(self, expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0] return expert_out else: + @dataclasses.dataclass class Capsule(): data: torch.Tensor @@ -307,7 +297,7 @@ class Capsule(): output = torch.empty_like(dispatch_data) def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: - return (slice(idx * chunk_size, (idx + 1) * chunk_size), ) + return (slice(idx * chunk_size, (idx + 1) * chunk_size),) _expert_in, expert_in, _expert_out, expert_out = None, None, None, None @@ -319,26 +309,21 @@ def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: # reduce scatter last output if _expert_out is not None: - expert_out = Capsule( - *ReduceScatter.apply(_expert_out.data, self.ep_group, True), - indices=_expert_out.indices - ) + expert_out = Capsule(*ReduceScatter.apply(_expert_out.data, self.ep_group, True), + indices=_expert_out.indices) _expert_out = None # all gather next input if 0 <= i < NUM_CHUNK: - _expert_in = Capsule( - *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True), - indices=get_chunk_slice(i, chunk_size) - ) + _expert_in = Capsule(*AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True), + indices=get_chunk_slice(i, chunk_size)) # compute if expert_in is not None: expert_in.handle.wait() - _expert_out = Capsule( - self.experts(expert_in.data, expert_in.indices), - handle=None, indices=expert_in.indices - ) + _expert_out = Capsule(self.experts(expert_in.data, expert_in.indices), + handle=None, + indices=expert_in.indices) expert_in = None if _expert_in is not None: diff --git 
a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py index b2fb672329c2..9ac6d5f3c0bd 100644 --- a/colossalai/moe/load_balance.py +++ b/colossalai/moe/load_balance.py @@ -423,6 +423,11 @@ def balance_load(self, optim: LowLevelZeroOptimizer) -> None: load = self._load_to_list(load) # search balance swap_list = self._search_balance(load) + if dist.get_rank() == 0: + if len(swap_list) > 0: + print(f"Apply swap...") + else: + print(f"Invalid swap, continue...") # swap expert and gate self._swap_moe_param(swap_list, optim) # clear load From e646d3de130c60f9dca46eee25e5eee3796dbc92 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 18 Oct 2023 11:27:23 +0800 Subject: [PATCH 02/17] update script --- examples/language/openmoe/benchmark/benchmark_cai.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index e1acba5c88b0..680ff53992fa 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -155,7 +155,7 @@ def main(): mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel, "enable_load_balance": args.load_balance} if args.plugin == "zero": dp_size = dist.get_world_size() - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2, verbose=True, precision="bf16") MOE_MANAGER.setup( parallel=None, **mgr_dict, @@ -169,6 +169,7 @@ def main(): ) MOE_MANAGER.setup( parallel="EP", + max_ep_size=dp_size, **mgr_dict, ) elif args.plugin == "ep_zero": From 7fe20d3fd4169148d713799af161a751fb49e472 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 18 Oct 2023 11:37:02 +0800 Subject: [PATCH 03/17] update experts --- colossalai/moe/experts.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 
076f160adb79..b8d418da15e8 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -97,7 +97,12 @@ def reset_parameters(self): torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size)) torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size)) - def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + param_slice: Tuple[slice] = (slice(None),), + use_sparse: bool = True, + ) -> torch.Tensor: """ Args: x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size) @@ -114,6 +119,16 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) - inshape = x.shape x = x.reshape(e, -1, h) + if use_sparse: + seq_len = x.shape[1] + with torch.no_grad(): + mask = x[:, :, 0] != 0.0 + mask = torch.sum(mask, dim=-1) + x_list = [] + for i in range(e): + x_list.append(x[i, :mask[i]]) + x = x_list + if self.gated: x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)] x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)] @@ -127,6 +142,10 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) - x = [self.drop(x[i]) for i in range(e)] x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)] + if use_sparse: + for i in range(e): + x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0) + x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() From ee9fe2c0d2fc2ea81eb8f7be69c11574ac7eda58 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 18 Oct 2023 11:50:13 +0800 Subject: [PATCH 04/17] update optim in fsdp --- examples/language/openmoe/benchmark/benchmark_fsdp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py 
b/examples/language/openmoe/benchmark/benchmark_fsdp.py index 0edf102d640c..531e18313798 100644 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.py +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py @@ -12,7 +12,6 @@ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.utils.data import Dataset from torch.utils.data.distributed import DistributedSampler -from transformers import Adafactor from transformers.models.llama import LlamaConfig from utils import PerformanceEvaluator, get_model_numel @@ -80,7 +79,7 @@ def fsdp_main(rank, world_size, args): auto_wrap_policy=auto_wrap_policy, device_id=torch.cuda.current_device(), ) - optimizer = Adafactor(model.parameters()) + optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5) model.train() model_numel = get_model_numel(model) From 115078e6d7d3bf2b278817d7f39ebd5151203492 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 18 Oct 2023 11:52:30 +0800 Subject: [PATCH 05/17] update kernel in sparse --- colossalai/moe/experts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index b8d418da15e8..40f3d59f42bd 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -119,7 +119,7 @@ def forward( inshape = x.shape x = x.reshape(e, -1, h) - if use_sparse: + if self.use_kernel and use_sparse: seq_len = x.shape[1] with torch.no_grad(): mask = x[:, :, 0] != 0.0 @@ -142,7 +142,7 @@ def forward( x = [self.drop(x[i]) for i in range(e)] x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)] - if use_sparse: + if self.use_kernel and use_sparse: for i in range(e): x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0) From 2378d466eef94eeaf25cd4b4f43b79db83801c19 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 18 Oct 2023 20:25:39 +0800 Subject: [PATCH 06/17] empty cache --- colossalai/moe/layers.py | 1 + 1 file 
changed, 1 insertion(+) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 8965ab1b2398..103074f8792f 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -345,4 +345,5 @@ def _apply_recursive(module: nn.Module): sub_module.load_balancer.balance_load(optim) _apply_recursive(sub_module) + torch.cuda.empty_cache() _apply_recursive(model) From a5deb238226d198cfcf570aa0f392425b76a0586 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 18 Oct 2023 22:23:09 +0800 Subject: [PATCH 07/17] update script --- colossalai/moe/layers.py | 14 +- colossalai/moe/manager.py | 4 +- colossalai/zero/low_level/low_level_optim.py | 176 ++++++++++++------ .../openmoe/benchmark/benchmark_cai.py | 54 +++--- .../openmoe/benchmark/benchmark_cai.sh | 67 +++++-- 5 files changed, 209 insertions(+), 106 deletions(-) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 103074f8792f..7510e95c4163 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -66,6 +66,7 @@ def __init__( self.use_kernel = MOE_MANAGER.use_kernel_optim self.expert_parallel = MOE_MANAGER.get_parallel() self.gated = gated + self.overlap = MOE_MANAGER.overlap_alltoall assert self.expert_parallel in [ "EP", "TP", @@ -164,9 +165,9 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # expert_output: (num_groups, num_experts, capacity, hidden_size) if self.expert_parallel == "EP": - expert_output = self._ep_process(dispatch_data) + expert_output = self._ep_process(dispatch_data, overlap=self.overlap) elif self.expert_parallel == "TP": - expert_output = self._tp_process(dispatch_data) + expert_output = self._tp_process(dispatch_data, overlap=self.overlap) elif self.expert_parallel is None: expert_output = self._local_process(dispatch_data) else: @@ -190,7 +191,7 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_out = self.experts(expert_in) return expert_out - def _ep_process(self, dispatch_data: 
torch.Tensor, overlap: bool = True) -> torch.Tensor: + def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor: """ Expert Parallel @@ -214,7 +215,7 @@ class Capsule(): data: torch.Tensor handle: Any = None - NUM_CHUNK = 2 + NUM_CHUNK = 4 NUM_STAGES = 4 assert dispatch_data.shape[1] % NUM_CHUNK == 0, \ @@ -256,7 +257,7 @@ class Capsule(): return output - def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = True) -> torch.Tensor: + def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor: """ without overlap: | C | @@ -287,7 +288,7 @@ class Capsule(): handle: Any indices: Tuple - NUM_CHUNK = 2 + NUM_CHUNK = 4 NUM_STAGES = 4 assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ @@ -347,3 +348,4 @@ def _apply_recursive(module: nn.Module): torch.cuda.empty_cache() _apply_recursive(model) + torch.cuda.empty_cache() diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py index e3659ef43fbd..ea09d4d6e037 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -34,6 +34,7 @@ def __init__(self): self.tolerance = None self.beam_width = None self.group_swap_factor = None + self.overlap_alltoall = None self.has_setup = False self._parallel_info_dict = dict() @@ -61,6 +62,7 @@ def setup( tolerance: float = 0.1, beam_width: int = 8, group_swap_factor: float = 0.4, + overlap_alltoall: bool = False, ) -> None: """ Setup MoE distributed context. 
@@ -109,7 +111,7 @@ def setup( self.tolerance = tolerance self.beam_width = beam_width self.group_swap_factor = group_swap_factor - + self.overlap_alltoall = overlap_alltoall self.has_setup = True def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]: diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index f08ebea58589..59c302103729 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -7,7 +7,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from torch._utils import _flatten_dense_tensors +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup from torch.optim import Optimizer @@ -281,6 +281,40 @@ def _run_reduction(self): if self.extra_dp_pg is None: flat_grads = self._bucket_store.get_flatten_grad() flat_grads /= self._world_size + else: + # record moe and non moe param + moe_list = [] + for param in self._bucket_store._param_list: + moe_list.append(is_moe_tensor(param)) + + # divide them into different groups + moe_grad_list = [] + non_moe_grad_list = [] + for grad_list in self._bucket_store._grad_in_bucket.values(): + non_moe_cur_grad = [] + moe_cur_grad = [] + for i in range(len(grad_list)): + if moe_list[i] == True: + moe_cur_grad.append(grad_list[i]) + else: + non_moe_cur_grad.append(grad_list[i]) + if len(moe_cur_grad) > 0: + moe_grad_list.append(moe_cur_grad) + if len(non_moe_cur_grad) > 0: + non_moe_grad_list.append(non_moe_cur_grad) + + if len(non_moe_grad_list) > 0: + non_moe_flat_grads = [] + for grad_list in non_moe_grad_list: + non_moe_flat_grads.append(_flatten_dense_tensors(grad_list)) + non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads) + non_moe_flat_grads /= self._world_size + + if len(moe_grad_list) > 0: + moe_flat_grads = [] + for grad_list in moe_grad_list: + 
moe_flat_grads.append(_flatten_dense_tensors(grad_list)) + moe_flat_grads = _flatten_dense_tensors(moe_flat_grads) # ready to add other tensors to bucket self._bucket_store.reset_num_elements_in_bucket() @@ -290,6 +324,11 @@ def _run_reduction(self): # in case of the memory being reused in the default stream if self.extra_dp_pg is None: flat_grads.record_stream(stream) + else: + if len(non_moe_grad_list) > 0: + non_moe_flat_grads.record_stream(stream) + if len(moe_grad_list) > 0: + moe_flat_grads.record_stream(stream) # waiting for ops in the default stream finishing stream.wait_stream(torch.cuda.current_stream()) else: @@ -324,64 +363,70 @@ def _run_reduction(self): # sync extra zero group else: - # record moe and non moe param - moe_list = [] - for param in self._bucket_store._param_list: - moe_list.append(is_moe_tensor(param)) - - # divide them into different groups - moe_grad_list = [] - non_moe_grad_list = [] - for grad_list in self._bucket_store._grad_in_bucket.values(): - non_moe_cur_grad = [] - moe_cur_grad = [] - for i in range(len(grad_list)): - if moe_list[i] == True: - moe_cur_grad.append(grad_list[i]) - else: - non_moe_cur_grad.append(grad_list[i]) - if len(moe_cur_grad) > 0: - moe_grad_list.append(moe_cur_grad) - if len(non_moe_cur_grad) > 0: - non_moe_grad_list.append(non_moe_cur_grad) - # sync non moe param in global dp group if len(non_moe_grad_list) > 0: - flat_grads = [] - for grad_list in non_moe_grad_list: - flat_grads.append(_flatten_dense_tensors(grad_list)) - flat_grads = _flatten_dense_tensors(flat_grads) - flat_grads /= self._world_size - dist.all_reduce(flat_grads, group=self.dp_pg) - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + dist.all_reduce(non_moe_flat_grads, group=self.dp_pg) + flat_grads_per_rank = non_moe_flat_grads.split(non_moe_flat_grads.numel() // + self._world_size) self._sync_unpartitioned_grad(non_moe_grad_list, flat_grads_per_rank, group_id) # sync moe param only in zero group if 
len(moe_grad_list) > 0: - flat_grads = [] - for grad_list in moe_grad_list: - flat_grads.append(_flatten_dense_tensors(grad_list)) - flat_grads = _flatten_dense_tensors(flat_grads) - dist.all_reduce(flat_grads, group=self.extra_dp_pg) - flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + dist.all_reduce(moe_flat_grads, group=self.extra_dp_pg) + flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size) self._sync_unpartitioned_grad(moe_grad_list, flat_grads_per_rank, group_id) else: - flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) - recieved_grad = torch.zeros_like(flat_grads_list[0]) - dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) - - if recieved_grad.dtype != grad_dtype: - recieved_grad = recieved_grad.to(grad_dtype) - - grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] - sync_tensor(recieved_grad, grad_in_bucket_current_rank) - for grad in grad_in_bucket_current_rank: - param_id = self._bucket_store.get_param_id_of_grad(grad) - if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: - self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - else: - self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + if self.extra_dp_pg is None: + flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) + + if recieved_grad.dtype != grad_dtype: + recieved_grad = recieved_grad.to(grad_dtype) + + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + sync_tensor(recieved_grad, grad_in_bucket_current_rank) + for grad in grad_in_bucket_current_rank: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: + 
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + else: + if len(non_moe_grad_list) > 0: + flat_grads_list = list(non_moe_flat_grads.split( + len(non_moe_flat_grads) // self._world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg) + + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + sync_tensor(recieved_grad, grad_in_bucket_current_rank) + for grad in grad_in_bucket_current_rank: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) + + if len(moe_grad_list) > 0: + flat_grads_list = list(moe_flat_grads.split(len(moe_flat_grads) // self.extra_dp_pg_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.extra_dp_pg) + + param_slice = self._world_size // self.extra_dp_pg_size + recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice)) + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + for split_recieved_grad in recieved_grad: + split_recieved_grad = _unflatten_dense_tensors(split_recieved_grad, + grad_in_bucket_current_rank) + for grad in grad_in_bucket_current_rank: + param_id = self._bucket_store.get_param_id_of_grad(grad) + if len(self._grad_store.get_partitioned_gradients_by_param_id( + group_id, param_id)) < param_slice: + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id) self._bucket_store.reset() @@ -512,18 +557,19 @@ def step(self, closure=None): # moe hybrid zero if 
self.extra_dp_pg is not None and is_moe_tensor(working_param): real_working_params[group_id].append(working_param) - param_slice = self._world_size // self.extra_dp_pg_size - grad = grads[self.extra_dp_pg_rank * param_slice:(self.extra_dp_pg_rank + 1) * param_slice] - grad = flatten(grad).to(splited_param.dtype).to(splited_param.device) - splited_param.grad = grad - grad_partition_groups.append(grad) - real_master_params[group_id].append(splited_param) + if self._partition_grads: + grad = grads + else: + param_slice = self._world_size // self.extra_dp_pg_size + grad = grads[self.extra_dp_pg_rank * param_slice:(self.extra_dp_pg_rank + 1) * param_slice] + grad = flatten(grad) else: real_working_params[group_id].append(working_param) - grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device) - splited_param.grad = grad - grad_partition_groups.append(grad) - real_master_params[group_id].append(splited_param) + grad = grads[grad_index] + grad = grad.to(splited_param.dtype).to(splited_param.device) + splited_param.grad = grad + grad_partition_groups.append(grad) + real_master_params[group_id].append(splited_param) # compute norm working_grads = self._grad_store.get_working_grads_by_group_id(group_id) @@ -539,13 +585,21 @@ def step(self, closure=None): global_norm = calculate_global_norm_from_list(norm_list=norm_groups) self._unscale_and_clip_grads(grad_partition_groups, global_norm) + # TODO: we should store master param for ep + if len(self.param_groups) > len(self._working_param_groups): + for param in self.param_groups[-1]['params']: + param.data = param.data.to(torch.float32) + param.grad = param.grad.to(torch.float32) + # update the parameters self.optim.step() - # release the moe grad + # TODO: release the moe grad. 
we should store master param if len(self.param_groups) > len(self._working_param_groups): + dtype = real_working_params[0][0].dtype for param in self.param_groups[-1]['params']: param.grad = None + param.data = param.data.to(dtype) # release the grad grad_partition_groups = [] diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 680ff53992fa..249ea424cc4c 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -15,12 +15,12 @@ import colossalai from colossalai import get_default_parser from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin from colossalai.cluster import DistCoordinator from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import skip_init +from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -118,7 +118,7 @@ def parse_args(): parser.add_argument("--pp_size", type=int, default=2, help="pp size") parser.add_argument("--dp_size", type=int, default=1, help="dp size") parser.add_argument("--ep_size", type=int, default=2, help="ep size") - parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") + parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin") parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") parser.add_argument("--extra_dp_size", type=int, default=1) # kernel @@ -132,6 +132,9 @@ def parse_args(): parser.add_argument("--active", type=int, default=20) # load balance parser.add_argument("--load_balance", action="store_true") + + # overlap + parser.add_argument("--overlap_alltoall", action="store_true") args = parser.parse_args() return args @@ -150,12 
+153,21 @@ def main(): "custom_policy": OpenMoeForCausalLMPolicy(), "enable_fused_normalization": args.use_kernel, "enable_jit_fused": args.use_kernel, - "precision": "bf16" + "precision": "bf16", + "zero_stage": args.zero_stage, + } + mgr_dict = { + "seed": 42, + "use_kernel_optim": args.use_kernel, + "enable_load_balance": args.load_balance, + "overlap_alltoall": args.overlap_alltoall } - mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel, "enable_load_balance": args.load_balance} if args.plugin == "zero": dp_size = dist.get_world_size() - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2, verbose=True, precision="bf16") + plugin = MoeHybridParallelPlugin( + pp_size=1, + **hybrid_dict, + ) MOE_MANAGER.setup( parallel=None, **mgr_dict, @@ -164,7 +176,6 @@ def main(): dp_size = dist.get_world_size() plugin = MoeHybridParallelPlugin( pp_size=1, - zero_stage=2, **hybrid_dict, ) MOE_MANAGER.setup( @@ -177,23 +188,6 @@ def main(): use_ep_inside = False plugin = MoeHybridParallelPlugin( pp_size=1, - zero_stage=1, - extra_dp_size=args.extra_dp_size, - use_ep_inside=use_ep_inside, - **hybrid_dict, - ) - MOE_MANAGER.setup( - parallel="EP", - max_ep_size=dp_size // args.extra_dp_size, - use_ep_inside=use_ep_inside, - **mgr_dict, - ) - elif args.plugin == "zero_ep": - dp_size = dist.get_world_size() - use_ep_inside = True - plugin = MoeHybridParallelPlugin( - pp_size=1, - zero_stage=1, extra_dp_size=args.extra_dp_size, use_ep_inside=use_ep_inside, **hybrid_dict, @@ -248,7 +242,7 @@ def main(): dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size) # Set optimizer - optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5) + optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5) model_numel = get_model_numel(model) performance_evaluator = PerformanceEvaluator( @@ -260,8 +254,8 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, dataloader, _ = 
booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) load_ckpt(repo_name, model, booster) + model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() coordinator.print_on_master(f"Finish init booster") @@ -272,6 +266,13 @@ def main(): train_dataloader_iter = iter(dataloader) total_len = len(train_dataloader_iter) - 1 exmaple_data = next(train_dataloader_iter) + # from torch.profiler import ProfilerActivity + # with torch.profiler.profile( + # activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + # schedule=torch.profiler.schedule(wait=0, warmup=args.warmup, active=1, repeat=1), + # on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"), + # with_stack=True, + # ) as prof: with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar: for step in pbar: performance_evaluator.on_step_start(step) @@ -303,8 +304,9 @@ def main(): optimizer.zero_grad() performance_evaluator.on_step_end(exmaple_data["input_ids"]) if (step == args.warmup // 2) and args.load_balance: - apply_load_balance(model, optimizer) coordinator.print_on_master(f"Apply load balance") + apply_load_balance(model, optimizer) + # prof.step() performance_evaluator.on_fit_end() diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh index ec4490faa55d..8ac1ae1d86a6 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -2,10 +2,10 @@ set -xue -NUM_GPU=4 +NUM_GPU=8 MODEL="8b" SEQ_LENGTH=2048 -WARMUP=8 +WARMUP=20 ACTIVE=4 # HACK: make model importable @@ -16,18 +16,22 @@ else export PYTHONPATH=$example_dir:$PYTHONPATH fi -# zero + +# ep +echo -e "\n\n EP \n\n" torchrun --standalone --nproc_per_node 
$NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size 4 \ + --batch_size 12 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ - --plugin zero \ - --use_kernel + --plugin ep \ + --use_kernel \ + --zero_stage 2 \ + --load_balance -# ep +echo -e "\n\n EP + Overlap \n\n" torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ @@ -36,9 +40,43 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --warmup $WARMUP \ --active $ACTIVE \ --plugin ep \ - --use_kernel + --use_kernel \ + --zero_stage 2 \ + --load_balance \ + --overlap_alltoall + # ep_zero +echo -e "\n\n EP-ZERO \n\n" +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 16 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin ep_zero \ + --use_kernel \ + --extra_dp_size 2 \ + --zero_stage 1 \ + --load_balance + +echo -e "\n\n EP-ZERO + Overlap \n\n" +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size 16 \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin ep_zero \ + --use_kernel \ + --extra_dp_size 2 \ + --zero_stage 1 \ + --load_balance \ + --overlap_alltoall + +echo -e "\n\n EP-ZERO-2 \n\n" torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ @@ -48,9 +86,11 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --active $ACTIVE \ --plugin ep_zero \ --use_kernel \ - --extra_dp_size 2 + --extra_dp_size 2 \ + --zero_stage 2 \ + --load_balance -# zero_ep +echo -e "\n\n EP-ZERO-2 \n\n" torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ @@ -58,9 +98,12 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE 
\ - --plugin zero_ep \ + --plugin ep_zero \ --use_kernel \ - --extra_dp_size 2 + --extra_dp_size 2 \ + --zero_stage 2 \ + --load_balance \ + --overlap_alltoall # hybrid torchrun --standalone --nproc_per_node $NUM_GPU \ From 4c701aa67f36160f59174b16b18e4cf728b7a31a Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 18 Oct 2023 22:31:15 +0800 Subject: [PATCH 08/17] update bench --- examples/language/openmoe/benchmark/benchmark_cai.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index 249ea424cc4c..f07151253fbc 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -266,13 +266,6 @@ def main(): train_dataloader_iter = iter(dataloader) total_len = len(train_dataloader_iter) - 1 exmaple_data = next(train_dataloader_iter) - # from torch.profiler import ProfilerActivity - # with torch.profiler.profile( - # activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - # schedule=torch.profiler.schedule(wait=0, warmup=args.warmup, active=1, repeat=1), - # on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"), - # with_stack=True, - # ) as prof: with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar: for step in pbar: performance_evaluator.on_step_start(step) @@ -306,7 +299,6 @@ def main(): if (step == args.warmup // 2) and args.load_balance: coordinator.print_on_master(f"Apply load balance") apply_load_balance(model, optimizer) - # prof.step() performance_evaluator.on_fit_end() From fc2c2b4cee021bb61c7fbdf6305f3dd3770f10fe Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 18 Oct 2023 22:32:22 +0800 Subject: [PATCH 09/17] update script --- examples/language/openmoe/benchmark/benchmark_cai.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh 
b/examples/language/openmoe/benchmark/benchmark_cai.sh index 8ac1ae1d86a6..b198ddd095fa 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -90,7 +90,7 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --zero_stage 2 \ --load_balance -echo -e "\n\n EP-ZERO-2 \n\n" +echo -e "\n\n EP-ZERO-2 + Overlap \n\n" torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ From fb3c966850c2c5d018749d86aa619f50f4875e70 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Wed, 18 Oct 2023 22:42:04 +0800 Subject: [PATCH 10/17] remove epzero2 --- .../openmoe/benchmark/benchmark_cai.sh | 29 ------------------- 1 file changed, 29 deletions(-) diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh index b198ddd095fa..fda468e7fd93 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -76,35 +76,6 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --load_balance \ --overlap_alltoall -echo -e "\n\n EP-ZERO-2 \n\n" -torchrun --standalone --nproc_per_node $NUM_GPU \ - $example_dir/benchmark/benchmark_cai.py \ - --model_name $MODEL \ - --batch_size 12 \ - --seq_length $SEQ_LENGTH \ - --warmup $WARMUP \ - --active $ACTIVE \ - --plugin ep_zero \ - --use_kernel \ - --extra_dp_size 2 \ - --zero_stage 2 \ - --load_balance - -echo -e "\n\n EP-ZERO-2 + Overlap \n\n" -torchrun --standalone --nproc_per_node $NUM_GPU \ - $example_dir/benchmark/benchmark_cai.py \ - --model_name $MODEL \ - --batch_size 12 \ - --seq_length $SEQ_LENGTH \ - --warmup $WARMUP \ - --active $ACTIVE \ - --plugin ep_zero \ - --use_kernel \ - --extra_dp_size 2 \ - --zero_stage 2 \ - --load_balance \ - --overlap_alltoall - # hybrid torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ From 
26e29eb1ec38ff9750ecb91e125efba4b197ca66 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 19 Oct 2023 14:00:40 +0800 Subject: [PATCH 11/17] fix --- .../openmoe/benchmark/benchmark_cai.sh | 23 ++++--------------- .../openmoe/benchmark/benchmark_fsdp.sh | 4 ++-- 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh index fda468e7fd93..f269e260d8db 100755 --- a/examples/language/openmoe/benchmark/benchmark_cai.sh +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -18,32 +18,16 @@ fi # ep -echo -e "\n\n EP \n\n" +echo -e "\n\n Naive EP \n\n" torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ --model_name $MODEL \ - --batch_size 12 \ + --batch_size 8 \ --seq_length $SEQ_LENGTH \ --warmup $WARMUP \ --active $ACTIVE \ --plugin ep \ - --use_kernel \ - --zero_stage 2 \ - --load_balance - -echo -e "\n\n EP + Overlap \n\n" -torchrun --standalone --nproc_per_node $NUM_GPU \ - $example_dir/benchmark/benchmark_cai.py \ - --model_name $MODEL \ - --batch_size 12 \ - --seq_length $SEQ_LENGTH \ - --warmup $WARMUP \ - --active $ACTIVE \ - --plugin ep \ - --use_kernel \ - --zero_stage 2 \ - --load_balance \ - --overlap_alltoall + --zero_stage 2 # ep_zero @@ -76,6 +60,7 @@ torchrun --standalone --nproc_per_node $NUM_GPU \ --load_balance \ --overlap_alltoall + # hybrid torchrun --standalone --nproc_per_node $NUM_GPU \ $example_dir/benchmark/benchmark_cai.py \ diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh index e1eb2a9c6053..0380ee1ade20 100755 --- a/examples/language/openmoe/benchmark/benchmark_fsdp.sh +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh @@ -6,8 +6,8 @@ NUM_GPU=8 MODEL="8b" BATCH_SIZE=1 SEQ_LENGTH=2048 -WARMUP=6 -ACTIVE=3 +WARMUP=8 +ACTIVE=4 # HACK: make model importable example_dir=$(dirname $(realpath $(dirname 
$0))) From 5129a79477c2dc60a9c692f740b5011590ab76c0 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 19 Oct 2023 14:02:48 +0800 Subject: [PATCH 12/17] update print --- colossalai/moe/load_balance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py index 9ac6d5f3c0bd..932a1e8e4647 100644 --- a/colossalai/moe/load_balance.py +++ b/colossalai/moe/load_balance.py @@ -425,9 +425,9 @@ def balance_load(self, optim: LowLevelZeroOptimizer) -> None: swap_list = self._search_balance(load) if dist.get_rank() == 0: if len(swap_list) > 0: - print(f"Apply swap...") + print(f"[Load Balance] Applying expert swap...") else: - print(f"Invalid swap, continue...") + print(f"[Load Balance] Invalid swap, skip...") # swap expert and gate self._swap_moe_param(swap_list, optim) # clear load From 0a60c961213ad78824dceafb7add1f443cb49646 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 19 Oct 2023 16:48:44 +0800 Subject: [PATCH 13/17] update test script --- tests/test_moe/test_moe_ep_tp.py | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 51fd135483b6..11d0664fd580 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -11,32 +11,20 @@ from tests.test_moe.moe_utils import MoeGradientHandler, sync_local_from_ep, sync_tp_from_ep -def run_test(rank: int, - world_size: int, - port: int, - num_experts: int, - batch_size: int, - dim: int, - seed: int): +def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int): assert batch_size % world_size == 0 colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') MOE_MANAGER.__init__() MOE_MANAGER.setup(seed, parallel=None) - local_model = SparseMLP(num_experts=num_experts, - hidden_size=dim, - 
intermediate_size=dim * 2) + local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) MOE_MANAGER.__init__() MOE_MANAGER.setup(seed, parallel="EP") - ep_model = SparseMLP(num_experts=num_experts, - hidden_size=dim, - intermediate_size=dim * 2) + ep_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) MOE_MANAGER.__init__() MOE_MANAGER.setup(seed, parallel="TP") - tp_model = SparseMLP(num_experts=num_experts, - hidden_size=dim, - intermediate_size=dim * 2) + tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2) ep_model = ep_model.to(get_current_device()) tp_model = tp_model.to(get_current_device()) local_model = local_model.to(get_current_device()) @@ -81,14 +69,11 @@ def run_test(rank: int, @pytest.mark.dist @pytest.mark.parametrize("num_experts", [4, 8]) -@pytest.mark.parametrize("batch_size", [4, 8]) -@pytest.mark.parametrize("dim", [16, 256]) -@pytest.mark.parametrize("seed", [42, 78]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("dim", [32]) +@pytest.mark.parametrize("seed", [42]) @rerun_if_address_is_in_use() -def test_moe_ep_tp(num_experts: int, - batch_size: int, - dim: int, - seed: int): +def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int): spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed) From 8123c3b2db1f77c3a82d14f29dbf877d8ac47bfb Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Thu, 19 Oct 2023 19:06:18 +0800 Subject: [PATCH 14/17] update script --- colossalai/moe/__init__.py | 15 ++- colossalai/moe/experts.py | 97 ++------------ colossalai/moe/layers.py | 126 ++++++++++-------- colossalai/moe/load_balance.py | 6 +- .../openmoe/model/modeling_openmoe.py | 16 +-- tests/test_moe/test_grad_handler.py | 23 +++- tests/test_moe/test_kernel.py | 8 +- tests/test_moe/test_moe_group.py | 57 ++++---- 8 files changed, 152 insertions(+), 196 deletions(-) diff --git 
a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py index 1614987538c1..f32e89dfad3f 100644 --- a/colossalai/moe/__init__.py +++ b/colossalai/moe/__init__.py @@ -1,12 +1,17 @@ from .checkpoint import MoeCheckpintIO -from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts +from .experts import MLPExperts from .layers import SparseMLP from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter from .utils import NormalNoiseGenerator, UniformNoiseGenerator __all__ = [ - 'EPMLPExperts', 'TPMLPExperts', 'build_ffn_experts', - 'MoeRouter', 'Top1Router', 'Top2Router', 'TopKRouter', - 'NormalNoiseGenerator', 'UniformNoiseGenerator', - 'SparseMLP', 'MoeCheckpintIO' + "MLPExperts", + "MoeRouter", + "Top1Router", + "Top2Router", + "TopKRouter", + "NormalNoiseGenerator", + "UniformNoiseGenerator", + "SparseMLP", + "MoeCheckpintIO", ] diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 40f3d59f42bd..3471b2876e9b 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -15,18 +15,19 @@ from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine -class BaseMLPExperts(nn.Module): +class MLPExperts(nn.Module): """ SparseMLP is a multi-layer perceptron with sparse expert parallel layers. Args: num_experts (int): The number of experts - forward: hidden_size --> intermediate_size --> hidden_size - hidden_size (int): The hidden size of MLP - intermediate_size (int): The intermediate size of MLP - expert_parallel (str, optional): The parallelism of experts. Now we have 'EP' and 'TP'. + hidden_size (int): The hidden size of MLP + intermediate_size (int): The intermediate size of MLP + expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP. 
activation (optional): The activation function of MLP drop_rate (float, optional): The drop rate of MLP + gated (bool, optional): Whether to use gated MLP + use_kernel (bool, optional): Whether to use kernel optimization """ def __init__( @@ -36,9 +37,9 @@ def __init__( intermediate_size: int, expert_parallel: Optional[str] = None, activation: Optional[Callable] = None, - drop_rate: float = 0, - gated: bool = False, - use_kernel: bool = False, + drop_rate: Optional[float] = 0, + gated: Optional[bool] = False, + use_kernel: Optional[bool] = False, ): super().__init__() assert expert_parallel in ["EP", "TP", None] @@ -104,6 +105,8 @@ def forward( use_sparse: bool = True, ) -> torch.Tensor: """ + forward: hidden_size --> intermediate_size --> hidden_size + Args: x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size) @@ -151,81 +154,3 @@ def forward( x = x.transpose(0, 1).contiguous() x = MoeOutGradScaler.apply(x, self.ep_size) return x - - -class EPMLPExperts(BaseMLPExperts): - """ - Use expert parallelism to split each expert evenly, which can deploy experts in - """ - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - activation=None, - drop_rate: float = 0, - gated: bool = False, - use_kernel: bool = False, - ): - # TODO: This class can be aborted - super().__init__( - num_experts, - hidden_size, - intermediate_size, - "EP", - activation, - drop_rate, - gated, - use_kernel, - ) - - -class TPMLPExperts(BaseMLPExperts): - """Use tensor parallelism to split each expert evenly, which can deploy experts in - case that the number of experts can't be divide by maximum expert parallel size or - maximum expert parallel size can't be divide by the number of experts. 
- """ - - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - activation: str = None, - drop_rate: float = 0, - gated: bool = False, - use_kernel: bool = False, - ): - # TODO: This class can be aborted - super().__init__( - num_experts, - hidden_size, - intermediate_size, - "TP", - activation, - drop_rate, - gated, - use_kernel, - ) - - -def get_expert_class(name: str) -> BaseMLPExperts: - if name == "TP": - return TPMLPExperts - elif name == "EP": - return EPMLPExperts - elif name is None: - return BaseMLPExperts - else: - raise ValueError(f"Unknown expert class name: {name}") - - -def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - mep_size = MOE_MANAGER.max_ep_size - if num_experts % mep_size == 0 or mep_size % num_experts == 0: - return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) - elif d_ff % mep_size == 0: - return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate) - else: - raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index 7510e95c4163..bd2cefbe9ab8 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter -from colossalai.moe.experts import BaseMLPExperts, get_expert_class +from colossalai.moe.experts import MLPExperts from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.routers import MoeRouter, get_router_cls @@ -48,51 +48,59 @@ class SparseMLP(nn.Module): def __init__( self, num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - hidden_size: int = 2048, - 
intermediate_size: int = 2048, - activation: str = None, - gated: bool = False, + hidden_size: int, + intermediate_size: int, + router_top_k: int = 1, + router_capacity_factor_train: Optional[float] = 1.25, + router_capacity_factor_eval: Optional[float] = 2.0, + router_min_capacity: Optional[int] = 4, + router_noisy_policy: Optional[str] = None, + router_drop_tks: Optional[bool] = True, + mlp_activation: Optional[str] = None, + mlp_gated: Optional[bool] = False, + enable_load_balance: Optional[bool] = False, + load_balance_tolerance: Optional[float] = 0.1, + load_balance_beam_width: Optional[int] = 8, + load_balance_group_swap_factor: Optional[float] = 0.4, + enable_kernel: Optional[bool] = False, + enable_comm_overlap: Optional[bool] = False, ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_experts = num_experts - self.use_kernel = MOE_MANAGER.use_kernel_optim + self.gated = mlp_gated + self.enable_kernel = enable_kernel + self.enable_comm_overlap = enable_comm_overlap self.expert_parallel = MOE_MANAGER.get_parallel() - self.gated = gated - self.overlap = MOE_MANAGER.overlap_alltoall - assert self.expert_parallel in [ - "EP", - "TP", - None, - ], f"Unsupported expert parallel type {self.expert_parallel}" # moe router - noisy_func = get_noise_generator(noisy_policy, num_experts) - router_cls = get_router_cls(top_k) - self.topk = top_k + noisy_func = get_noise_generator(router_noisy_policy, num_experts) + router_cls = get_router_cls(router_top_k) + self.topk = router_top_k self.router: MoeRouter = router_cls( - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, + capacity_factor_train=router_capacity_factor_train, + capacity_factor_eval=router_capacity_factor_eval, + min_capacity=router_min_capacity, noisy_func=noisy_func, - drop_tks=drop_tks, + drop_tks=router_drop_tks, ) + # gate + self.gate_weight = 
torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) + # moe experts - expert_cls = get_expert_class(self.expert_parallel) - self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - activation=activation, - gated=gated, - use_kernel=self.use_kernel) + self.experts = MLPExperts( + num_experts=self.num_experts, + expert_parallel=self.expert_parallel, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + activation=mlp_activation, + gated=mlp_gated, + use_kernel=self.enable_kernel, + ) + + # get parallel settings if self.expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) @@ -102,11 +110,8 @@ def __init__( self.dp_group = None self.num_local_experts = self.experts.num_local_experts - # gate - self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) - # load balance - self.enable_load_balance = MOE_MANAGER.load_balance + self.enable_load_balance = enable_load_balance if self.enable_load_balance == True: self.load_balancer = LoadBalancer( experts=self.experts, @@ -115,9 +120,9 @@ def __init__( expert_num=self.num_experts, ep_group=self.ep_group, dp_group=self.dp_group, - tolerance=MOE_MANAGER.tolerance, - beam_width=MOE_MANAGER.beam_width, - group_swap_factor=MOE_MANAGER.group_swap_factor, + tolerance=load_balance_tolerance, + beam_width=load_balance_beam_width, + group_swap_factor=load_balance_group_swap_factor, ) # init param @@ -153,10 +158,10 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: self.load_balancer.update_load(expert_load) # the result from the router - route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) + route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group) # dispatch_data: (num_experts, capacity, hidden_size) - if 
self.use_kernel: + if self.enable_kernel: dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size) else: @@ -165,16 +170,16 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # expert_output: (num_groups, num_experts, capacity, hidden_size) if self.expert_parallel == "EP": - expert_output = self._ep_process(dispatch_data, overlap=self.overlap) + expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap) elif self.expert_parallel == "TP": - expert_output = self._tp_process(dispatch_data, overlap=self.overlap) + expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap) elif self.expert_parallel is None: expert_output = self._local_process(dispatch_data) else: raise NotImplementedError("This kind of communication has not been implemented yet.\n" "Please use Experts build function.") - if self.use_kernel: + if self.enable_kernel: expert_output = expert_output.reshape(-1, self.hidden_size) ans = MoeCombine.apply(expert_output, *route_result_list) else: @@ -211,15 +216,14 @@ def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> tor else: @dataclasses.dataclass - class Capsule(): + class Capsule: data: torch.Tensor handle: Any = None NUM_CHUNK = 4 NUM_STAGES = 4 - assert dispatch_data.shape[1] % NUM_CHUNK == 0, \ - "arbitrary chunk num is not supported yet" + assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet" chunk_size = dispatch_data.shape[1] // NUM_CHUNK input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size) dispatch_data = dispatch_data.reshape(*input_shape) @@ -283,7 +287,7 @@ def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> tor else: @dataclasses.dataclass - class Capsule(): + class Capsule: data: torch.Tensor handle: Any indices: Tuple @@ -291,8 +295,8 @@ class Capsule(): NUM_CHUNK = 4 
NUM_STAGES = 4 - assert dispatch_data.shape[0] % NUM_CHUNK == 0, \ - "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" + assert (dispatch_data.shape[0] % NUM_CHUNK == 0 + ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts" chunk_size = dispatch_data.shape[0] // NUM_CHUNK chunk_data = torch.split(dispatch_data, chunk_size, dim=0) output = torch.empty_like(dispatch_data) @@ -310,21 +314,27 @@ def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]: # reduce scatter last output if _expert_out is not None: - expert_out = Capsule(*ReduceScatter.apply(_expert_out.data, self.ep_group, True), - indices=_expert_out.indices) + expert_out = Capsule( + *ReduceScatter.apply(_expert_out.data, self.ep_group, True), + indices=_expert_out.indices, + ) _expert_out = None # all gather next input if 0 <= i < NUM_CHUNK: - _expert_in = Capsule(*AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True), - indices=get_chunk_slice(i, chunk_size)) + _expert_in = Capsule( + *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True), + indices=get_chunk_slice(i, chunk_size), + ) # compute if expert_in is not None: expert_in.handle.wait() - _expert_out = Capsule(self.experts(expert_in.data, expert_in.indices), - handle=None, - indices=expert_in.indices) + _expert_out = Capsule( + self.experts(expert_in.data, expert_in.indices), + handle=None, + indices=expert_in.indices, + ) expert_in = None if _expert_in is not None: diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py index 932a1e8e4647..4a3d0fe4d096 100644 --- a/colossalai/moe/load_balance.py +++ b/colossalai/moe/load_balance.py @@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup from colossalai.cluster import ProcessGroupMesh -from colossalai.moe.experts import BaseMLPExperts +from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER from colossalai.zero.low_level 
import LowLevelZeroOptimizer @@ -16,7 +16,7 @@ class LoadBalancer: def __init__( self, - experts: BaseMLPExperts, + experts: MLPExperts, gate: nn.Parameter, local_expert_num: int, expert_num: int, @@ -26,7 +26,7 @@ def __init__( beam_width: Optional[int] = 8, group_swap_factor: Optional[float] = 0.4, ) -> None: - self.experts: BaseMLPExperts = experts + self.experts: MLPExperts = experts self.gate: nn.Parameter = gate self.moe_ep_group: ProcessGroup = ep_group self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 357c0f22a783..f4dba898d478 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -361,16 +361,16 @@ def __init__(self, config: LlamaConfig, moe: bool): self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: self.mlp = SparseMLP(num_experts=config.num_experts, - top_k=config.topk, - capacity_factor_train=config.capacity_factor_train, - capacity_factor_eval=config.capacity_factor_eval, - min_capacity=config.min_capacity, - noisy_policy=config.noisy_policy, - drop_tks=config.drop_tks, + router_top_k=config.topk, + router_capacity_factor_train=config.capacity_factor_train, + router_capacity_factor_eval=config.capacity_factor_eval, + router_min_capacity=config.min_capacity, + router_noisy_policy=config.noisy_policy, + router_drop_tks=config.drop_tks, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - activation=config.hidden_act, - gated=config.gated) + mlp_activation=config.hidden_act, + mlp_gated=config.gated) self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = OpenMoeMLP(config) else: diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 
13e142aadd7a..28ee618e1ba7 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -16,17 +16,26 @@ def run_test(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) MOE_MANAGER.setup(42, parallel="EP") # MOE initialization num_experts_list = [1, 2, 4] layer_list = [] for num_experts in num_experts_list: - moe_layer = SparseMLP(hidden_size=DIM, - intermediate_size=DIM * 4, - num_experts=num_experts, - top_k=1, - noisy_policy="Jitter") + moe_layer = SparseMLP( + hidden_size=DIM, + intermediate_size=DIM * 4, + num_experts=num_experts, + router_top_k=1, + router_noisy_policy="Jitter", + ) layer_list.append(moe_layer) model = nn.ModuleList(layer_list) @@ -77,5 +86,5 @@ def test_grad_handler(): spawn(run_test, 4) -if __name__ == '__main__': +if __name__ == "__main__": test_grad_handler() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index db40110d8d9e..4b9885ab0691 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -33,14 +33,14 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer = SparseMLP(hidden_size=hidden_size, intermediate_size=hidden_size * 2, num_experts=NUM_EXPERTS, - top_k=topk, - capacity_factor_train=1.0) + router_top_k=topk, + router_capacity_factor_train=1.0) layer = layer.to(get_current_device()) if data_type == torch.float16: layer = layer.half() # use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine - layer.use_kernel = False + layer.enable_kernel = False old_out = layer(tokens) ech = old_out.shape grad = torch.randn(ech, device=get_current_device()) @@ -54,7 +54,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f tokens.grad.zero_() 
layer.gate_weight.grad.zero_() - layer.use_kernel = True + layer.enable_kernel = True new_out = layer(tokens) # get outputs through colossal kernel if data_type == torch.float32: diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index e111ea6bb18d..3cd5acc0d953 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -3,7 +3,7 @@ import torch.nn as nn import colossalai -from colossalai.moe import EPMLPExperts, TPMLPExperts +from colossalai.moe.experts import MLPExperts from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import sync_moe_model_param from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn @@ -13,38 +13,39 @@ INTERMEDIATE_SIZE = 8 -def run_moe_init(expert_cls): - expert_args = dict(hidden_size=HIDDEN_SIZE, intermediate_size=INTERMEDIATE_SIZE) - exp0 = expert_cls(1, **expert_args) - exp1 = expert_cls(2, **expert_args) - exp2 = expert_cls(4, **expert_args) - exp3 = expert_cls(8, **expert_args) +def run_moe_init(expert_parallel): + MOE_MANAGER.__init__() + MOE_MANAGER.setup(seed=42, parallel=expert_parallel) + expert_args = dict( + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + expert_parallel=expert_parallel, + ) + exp0 = MLPExperts(1, **expert_args) + exp1 = MLPExperts(2, **expert_args) + exp2 = MLPExperts(4, **expert_args) - if expert_cls is EPMLPExperts: + if expert_parallel == "EP": assert exp0.num_local_experts == 1 assert exp1.num_local_experts == 1 - assert exp2.num_local_experts == 1 - assert exp3.num_local_experts == 2 + assert exp2.num_local_experts == 2 else: assert exp0.num_local_experts == 1 assert exp1.num_local_experts == 2 assert exp2.num_local_experts == 4 - assert exp3.num_local_experts == 8 parallel_info_dict = MOE_MANAGER.parallel_info_dict rank = dist.get_rank() # group creation assert - assert len(parallel_info_dict) == 3 - assert dist.get_rank(parallel_info_dict[4].ep_group) == rank + assert 
len(parallel_info_dict) == 2 assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2 assert dist.get_rank(parallel_info_dict[1].ep_group) == 0 - assert dist.get_rank(parallel_info_dict[4].dp_group) == 0 assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2 assert dist.get_rank(parallel_info_dict[1].dp_group) == rank - model = nn.ModuleList([exp0, exp1, exp2, exp3]) + model = nn.ModuleList([exp0, exp1, exp2]) model = model.to(get_current_device()) sync_moe_model_param(model) @@ -57,19 +58,25 @@ def run_moe_init(expert_cls): assert_equal_in_group(exp1.wo.data, parallel_info_dict[2].dp_group) -def _run_test(rank, world_size, port, expert_cls): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_MANAGER.setup(seed=42, parallel="EP") - run_moe_init(expert_cls) +def _run_test(rank, world_size, port, expert_parallel): + colossalai.launch( + config=dict(), + rank=rank, + world_size=world_size, + host="localhost", + port=port, + backend="nccl", + ) + run_moe_init(expert_parallel) @pytest.mark.dist -@pytest.mark.parametrize("expert_cls", [EPMLPExperts, TPMLPExperts]) +@pytest.mark.parametrize("expert_parallel", ["EP", "TP"]) @rerun_if_address_is_in_use() -def test_moe_initialization(expert_cls): - spawn(_run_test, 4, expert_cls=expert_cls) +def test_moe_initialization(expert_parallel): + spawn(_run_test, 2, expert_parallel=expert_parallel) -if __name__ == '__main__': - test_moe_initialization(EPMLPExperts) - test_moe_initialization(TPMLPExperts) +if __name__ == "__main__": + test_moe_initialization("EP") + test_moe_initialization("TP") From 82afebc06de4709f35903b2c0b730436209746fd Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Fri, 20 Oct 2023 10:05:34 +0800 Subject: [PATCH 15/17] update manager --- colossalai/moe/manager.py | 87 +++++++++++++-------------------------- 1 file changed, 29 insertions(+), 58 deletions(-) diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py 
index ea09d4d6e037..1b61965b83fe 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -14,30 +14,29 @@ class MoeManager(metaclass=SingletonMeta): """ def __init__(self): - self.world_size = None - # Users may want to set maximum expert parallel size smaller than the world size - # since very low bandwidth across nodes may constrain the performance of MoE - # When we have a maximum expert parallel size, we have a minimum data parallel size naturally - self.max_ep_size = None - self.min_dp_size = None - self.router_aux_loss = [] - self.router_z_loss = [] self.parallel = None self.seed = None self.mode = None - self.use_kernel_optim = False self.use_ep_inside = None + self.world_size = None + self._parallel_info_dict = dict() + + # router + self.router_aux_loss = [] + self.router_z_loss = [] + + # fixed mode self.pp_size = None + self.dp_size = None + self.ep_size = None - # load balance param - self.load_balance = None - self.tolerance = None - self.beam_width = None - self.group_swap_factor = None - self.overlap_alltoall = None + # dynamic mode + # Users may want to set maximum expert parallel size smaller than the world size + # since very low bandwidth across nodes may constrain the performance of MoE + # When we have a maximum expert parallel size, we have a minimum data parallel size naturally + self.max_ep_size = None self.has_setup = False - self._parallel_info_dict = dict() @property def parallel_info_dict(self): @@ -50,7 +49,6 @@ def is_initialized(self): def setup( self, seed: int, - use_kernel_optim: bool = False, parallel: str = None, mode: str = "dynamic", max_ep_size: int = 8, @@ -58,11 +56,6 @@ def setup( fixed_ep_size: int = 0, fixed_pp_size: int = 0, use_ep_inside: bool = True, - enable_load_balance: bool = False, - tolerance: float = 0.1, - beam_width: int = 8, - group_swap_factor: float = 0.4, - overlap_alltoall: bool = False, ) -> None: """ Setup MoE distributed context. 
@@ -80,38 +73,28 @@ def setup( fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0. use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if Fasle. Defaults to True. """ - assert not self.is_initialized, "MoE distributed context shouldn't be set up again" + assert (not self.is_initialized), "MoE distributed context shouldn't be set up again" assert torch.cuda.is_available(), "MoE requires to enable CUDA first" - self.world_size = dist.get_world_size() self.seed = seed + dist.get_rank() self.parallel = parallel self.use_ep_inside = use_ep_inside + self.world_size = dist.get_world_size() # init by mode self.mode = mode assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic" if self.mode == "dynamic": - self.max_ep_size = min(max_ep_size, dist.get_world_size()) - self.min_dp_size = self.world_size // self.max_ep_size + self.max_ep_size = min(max_ep_size, self.world_size) else: - assert fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0, "dp_size, ep_size and pp_size should be greater than 0" - assert isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance( - fixed_pp_size, int), "dp_size, ep_size and pp_size should be int" + assert (fixed_dp_size > 0 and fixed_ep_size > 0 + and fixed_pp_size > 0), "dp_size, ep_size and pp_size should be greater than 0" + assert (isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) + and isinstance(fixed_pp_size, int)), "dp_size, ep_size and pp_size should be int" self.ep_size = fixed_ep_size self.dp_size = fixed_dp_size self.pp_size = fixed_pp_size - # Enabling kernel optimization may raise error in some cases - # Users can close kernel optimization manually - self.use_kernel_optim = use_kernel_optim - - # update load balance - self.load_balance = enable_load_balance - self.tolerance = tolerance - self.beam_width = beam_width - self.group_swap_factor = group_swap_factor - self.overlap_alltoall = overlap_alltoall self.has_setup = 
True def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]: @@ -129,21 +112,12 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara """ if self.mode == "dynamic": - gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater - lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less - - assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \ " is not a multiple of ep size or vice versa." - - # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size, - # there are multiple experts in each GPU and each GPU has different experts - # So it's data parallel size is 1 - # Otherwise, there is only one expert in each GPU - # The data parallel size should be calculated - dp_size = 1 if gt_flag else self.max_ep_size // num_experts - ep_size = self.max_ep_size // dp_size - # Don't forget to multiply minimum data parallel size - dp_size *= self.min_dp_size + gt_flag = (num_experts % self.max_ep_size == 0) # check whether num_experts is greater + lt_flag = (self.max_ep_size % num_experts == 0) # check whether num_experts is less + assert gt_flag or lt_flag, ("Automatic experts placement does not support expert number" + " is not a multiple of ep size or vice versa.") + dp_size = 1 if gt_flag else self.world_size // num_experts + ep_size = self.world_size // dp_size pp_size = 1 else: dp_size = self.dp_size @@ -169,13 +143,10 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara return num_local_experts, self.parallel_info_dict[ep_size] - def set_kernel_not_use(self): - self.use_kernel_optim = False - def reset_loss(self): self.router_aux_loss, self.router_z_loss = [], [] - def add_loss(self, aux_loss: float = 0., z_loss: float = 0.): self.router_aux_loss.append(aux_loss) 
self.router_z_loss.append(z_loss) From 1fb95a594339675f8f2876fc6b623db3ec814569 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Fri, 20 Oct 2023 10:07:23 +0800 Subject: [PATCH 16/17] update host --- examples/language/openmoe/benchmark/hostfile.txt | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 examples/language/openmoe/benchmark/hostfile.txt diff --git a/examples/language/openmoe/benchmark/hostfile.txt b/examples/language/openmoe/benchmark/hostfile.txt new file mode 100644 index 000000000000..994b3e2cfc4f --- /dev/null +++ b/examples/language/openmoe/benchmark/hostfile.txt @@ -0,0 +1,2 @@ +host1 +host2 From f46abb141160145a548c6e0ab3d0c5b49393b6ad Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Fri, 20 Oct 2023 13:14:46 +0800 Subject: [PATCH 17/17] update script --- colossalai/moe/_operation.py | 4 +- colossalai/moe/manager.py | 3 +- colossalai/moe/routers.py | 83 +++++++------------ colossalai/moe/utils.py | 7 +- .../openmoe/benchmark/benchmark_cai.py | 31 +++++-- examples/language/openmoe/infer.py | 49 ++++++++++- .../openmoe/model/modeling_openmoe.py | 30 ++++--- examples/language/openmoe/train.py | 31 +++++-- tests/test_moe/moe_utils.py | 7 +- tests/test_moe/test_moe_checkpoint.py | 14 ++-- tests/test_moe/test_moe_load_balance.py | 12 +-- 11 files changed, 170 insertions(+), 101 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index 9b4988b345d9..14e0935b72e4 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -227,10 +227,10 @@ def backward(ctx, tokens_grad): return d_expert, d_logits, None, None, None -def moe_cumsum(inputs: Tensor): +def moe_cumsum(inputs: Tensor, use_kernel: bool = False): dim0 = inputs.size(0) flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) - if flag and MOE_MANAGER.use_kernel_optim: + if flag and use_kernel: if MOE_KERNEL is None: load_moe() return MOE_KERNEL.cumsum_sub_one(inputs) diff --git a/colossalai/moe/manager.py 
b/colossalai/moe/manager.py index 1b61965b83fe..f237ea134638 100644 --- a/colossalai/moe/manager.py +++ b/colossalai/moe/manager.py @@ -117,7 +117,8 @@ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoePara assert gt_flag or lt_flag, ("Automatic experts placement does not support expert number" " is not a multiple of ep size or vice versa.") dp_size = 1 if gt_flag else self.world_size // num_experts - ep_size = self.world_size // dp_size + ep_size = min(self.world_size // dp_size, self.max_ep_size) + dp_size = self.world_size // ep_size pp_size = 1 else: dp_size = self.dp_size diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index 1ac66f7bb78f..7960a74d4539 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -30,7 +30,8 @@ def __init__(self, capacity_factor_eval: float, min_capacity: int, noisy_func: Optional[Callable] = None, - drop_tks: bool = True): + drop_tks: bool = True, + use_kernel: bool = False): super().__init__() self.k_value = k_value self.capacity_factor_train = capacity_factor_train @@ -40,6 +41,7 @@ def __init__(self, self.drop_tks = drop_tks self._aux_loss = None self._z_loss = None + self.use_kernel = use_kernel def get_capacity(self, logits_shape): capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval @@ -49,11 +51,7 @@ def get_capacity(self, logits_shape): assert capacity > 0 return int(capacity) - def set_aux_loss(self, - router_probs: torch.Tensor, - expert_indices: torch.Tensor, - num_experts: int - ) -> None: + def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None: """Computes auxiliary load balancing loss as in Switch Transformer. See Switch Transformer (https://arxiv.org/abs/2101.03961). 
This function @@ -81,8 +79,7 @@ def set_aux_loss(self, tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2) router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2) - aux_loss = num_experts**2 * torch.mean( - tokens_per_group_and_expert * router_prob_per_group_and_expert) + aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) self._aux_loss = aux_loss def set_z_loss(self, router_logits: torch.Tensor): @@ -101,8 +98,7 @@ def set_z_loss(self, router_logits: torch.Tensor): assert router_logits.dim() == 3, "router_logits must be 3D tensor" num_groups, tokens_per_group, _ = router_logits.shape log_z = torch.logsumexp(router_logits, dim=-1) - z_loss = torch.sum(log_z**2, dtype=torch.float32 - ) / (num_groups * tokens_per_group) + z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group) self._z_loss = z_loss def pop_router_loss(self) -> torch.Tensor: @@ -113,8 +109,8 @@ def pop_router_loss(self) -> torch.Tensor: class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) - and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed + """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) + and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed function can be found in the paper about Switch Transformer of Google. 
Args: @@ -142,22 +138,17 @@ def __init__(self, self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform( - low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, device=get_current_device()) - ).rsample - - def forward(self, - inputs: torch.Tensor, - use_kernel: bool = False, - ep_group: Optional[ProcessGroup] = None - ) -> Tuple: + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), + high=torch.tensor(1.0, + device=get_current_device())).rsample + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). Returns: - 1. use_kernel is False: + 1. use_kernel is False: The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). 2. use_kernel is True: @@ -188,9 +179,9 @@ def forward(self, rand_mask = mask * self.uniform(mask.shape) _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) - ranks = moe_cumsum(mask) + ranks = moe_cumsum(mask, use_kernel=self.use_kernel) elif self.select_policy == "first": - ranks = moe_cumsum(mask) + ranks = moe_cumsum(mask, use_kernel=self.use_kernel) mask = mask * torch.lt(ranks, capacity) else: raise NotImplementedError("Not support such select policy yet.") @@ -211,8 +202,8 @@ def forward(self, class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) - and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. 
More detailed + """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) + and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed function can be found in the paper about ViT-MoE. Args: @@ -236,17 +227,13 @@ def __init__(self, noisy_func=noisy_func, drop_tks=drop_tks) - def forward(self, - inputs: torch.Tensor, - use_kernel: bool = False, - ep_group: Optional[ProcessGroup] = None - ) -> Tuple: + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). Returns: - 1. use_kernel is False: + 1. use_kernel is False: The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). 2. use_kernel is True: @@ -280,8 +267,8 @@ def forward(self, dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() - rank1 = moe_cumsum(mask1) # rank1: [s, e] - rank2 = moe_cumsum(mask2) + rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] + rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) rank2 += torch.sum(mask1, dim=-2, keepdim=True) mask1 *= torch.lt(rank1, capacity) @@ -313,7 +300,7 @@ def forward(self, weight1 = mask1 * probs.type_as(inputs) weight2 = mask2 * probs.type_as(inputs) - cb_weight = torch.zeros(inputs.shape + (capacity, ), device=inputs.device) + cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device) sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool) indices = torch.arange(0, inputs.shape[0], device=inputs.device) cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]] @@ -348,17 +335,14 @@ def __init__(self, min_capacity: int = 4, noisy_func: Optional[Callable] = None, drop_tks: bool = True): - 
super().__init__(num_selected_experts, - capacity_factor_train, - capacity_factor_eval, - min_capacity, - noisy_func, + super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks) - def forward(self, - router_probs: torch.Tensor, - expert_capacity: int, - ) -> Tuple: + def forward( + self, + router_probs: torch.Tensor, + expert_capacity: int, + ) -> Tuple: """Computes masks for the top-k experts per token. Args: @@ -418,17 +402,12 @@ def forward(self, # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. - combine_array = torch.einsum( - '...te,...tec->...tec', - router_probs, - dispatch_mask) + combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask) return combine_array, dispatch_mask -def get_router_cls(top_k: int, - grouped: bool = False - ) -> MoeRouter: +def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter: if not grouped: if top_k == 1: return Top1Router diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py index e3bc6d3cac9a..0938e4206fda 100644 --- a/colossalai/moe/utils.py +++ b/colossalai/moe/utils.py @@ -1,5 +1,5 @@ import contextlib -from typing import Callable, Dict, List +from typing import Any, Callable, Dict, List import torch import torch.distributed as dist @@ -170,3 +170,8 @@ def sync_moe_model_param(model: nn.Module): for param in param_dict[ep_size]: src_rank = get_dp_group_ranks(param)[0] dist.broadcast(param, src=src_rank, group=get_dp_group(param)) + + +def set_moe_args(config: Any, args: dict): + for k, v in args.items(): + setattr(config, k, v) diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py index f07151253fbc..2f6bfa0f89a2 100644 --- a/examples/language/openmoe/benchmark/benchmark_cai.py +++ 
b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -19,7 +19,7 @@ from colossalai.cluster import DistCoordinator from colossalai.moe.layers import apply_load_balance from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import skip_init +from colossalai.moe.utils import set_moe_args, skip_init from colossalai.nn.optimizer import HybridAdam from colossalai.utils import get_current_device @@ -158,9 +158,6 @@ def main(): } mgr_dict = { "seed": 42, - "use_kernel_optim": args.use_kernel, - "enable_load_balance": args.load_balance, - "overlap_alltoall": args.overlap_alltoall } if args.plugin == "zero": dp_size = dist.get_world_size() @@ -221,10 +218,28 @@ def main(): # Build OpenMoe model repo_name = "hpcaitech/openmoe-" + args.model_name config = LlamaConfig.from_pretrained(repo_name) - setattr(config, "router_aux_loss_factor", 0.1) - setattr(config, "router_z_loss_factor", 0.1) - setattr(config, "label_smoothing", 0.1) - setattr(config, "z_loss_factor", 0.1) + moe_args = { + "num_experts": config.num_experts, + "moe_layer_interval": config.moe_layer_interval, + "router_topk": 2, + "router_capacity_factor_train": 1.25, + "router_capacity_factor_eval": 2.0, + "router_min_capacity": 4, + "router_noisy_policy": None, + "router_drop_tks": True, + "router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.01, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": args.load_balance, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": args.use_kernel, + "enable_comm_overlap": args.overlap_alltoall, + } + set_moe_args(config, moe_args) with skip_init(): model = OpenMoeForCausalLM(config) coordinator.print_on_master(f"Finish init model with config:\n{config}") diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py index f59772189827..1ad1456b9c56 100644 --- a/examples/language/openmoe/infer.py +++ 
b/examples/language/openmoe/infer.py @@ -5,6 +5,8 @@ from transformers import T5Tokenizer from transformers.models.llama import LlamaConfig +from colossalai.moe.utils import set_moe_args + def parse_args(): parser = ArgumentParser() @@ -17,9 +19,54 @@ def inference(args): tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") if args.model == "test": config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") + moe_args = { + "num_experts": config.num_experts, + "moe_layer_interval": config.moe_layer_interval, + "router_topk": 2, + "router_capacity_factor_train": 1.25, + "router_capacity_factor_eval": 2.0, + "router_min_capacity": 4, + "router_noisy_policy": None, + "router_drop_tks": True, + "router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.01, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": False, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": False, + "enable_comm_overlap": False, + } + set_moe_args(config, moe_args) model = OpenMoeForCausalLM(config) else: - model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}") + config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}") + moe_args = { + "num_experts": config.num_experts, + "moe_layer_interval": config.moe_layer_interval, + "router_topk": 2, + "router_capacity_factor_train": 1.25, + "router_capacity_factor_eval": 2.0, + "router_min_capacity": 4, + "router_noisy_policy": None, + "router_drop_tks": True, + "router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.01, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": False, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": False, + "enable_comm_overlap": False, + } + set_moe_args(config, moe_args) + model = 
OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config) model = model.eval().half() model = model.to(torch.cuda.current_device()) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index f4dba898d478..6f9b668e4597 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -39,6 +39,7 @@ from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import get_activation if HAS_TRITON: from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine @@ -166,7 +167,7 @@ def SwiGLU(x): class OpenMoeMLP(nn.Module): - def __init__(self, config): + def __init__(self, config: LlamaConfig): super().__init__() self.pretraining_tp = config.pretraining_tp self.hidden_size = config.hidden_size @@ -174,8 +175,9 @@ def __init__(self, config): self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = SwiGLU - self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False + self.hidden_act = config.hidden_act + self.act_fn = get_activation(self.hidden_act) + self.use_kernel = config.enable_kernel def forward(self, x): if self.pretraining_tp > 1: @@ -191,7 +193,7 @@ def forward(self, x): down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] down_proj = sum(down_proj) else: - if HAS_TRITON and self.use_kernel: + if HAS_TRITON and self.use_kernel and self.hidden_act == "swiglu": down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x))) else: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * 
self.up_proj(x)) @@ -361,16 +363,22 @@ def __init__(self, config: LlamaConfig, moe: bool): self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: self.mlp = SparseMLP(num_experts=config.num_experts, - router_top_k=config.topk, - router_capacity_factor_train=config.capacity_factor_train, - router_capacity_factor_eval=config.capacity_factor_eval, - router_min_capacity=config.min_capacity, - router_noisy_policy=config.noisy_policy, - router_drop_tks=config.drop_tks, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, + router_top_k=config.router_topk, + router_capacity_factor_train=config.router_capacity_factor_train, + router_capacity_factor_eval=config.router_capacity_factor_eval, + router_min_capacity=config.router_min_capacity, + router_noisy_policy=config.router_noisy_policy, + router_drop_tks=config.router_drop_tks, mlp_activation=config.hidden_act, - mlp_gated=config.gated) + mlp_gated=config.mlp_gated, + enable_load_balance=config.enable_load_balance, + load_balance_tolerance=config.load_balance_tolerance, + load_balance_beam_width=config.load_balance_beam_width, + load_balance_group_swap_factor=config.load_balance_group_swap_factor, + enable_kernel=config.enable_kernel, + enable_comm_overlap=config.enable_comm_overlap) self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = OpenMoeMLP(config) else: diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 6f239104328c..ec9ec21b55dc 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -19,7 +19,7 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.moe import MoeCheckpintIO from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import skip_init +from colossalai.moe.utils import set_moe_args, skip_init from colossalai.utils import get_current_device @@ -157,7 
+157,6 @@ def main(): MOE_MANAGER.setup( seed=42, parallel="EP", - use_kernel_optim=args.use_kernel if not test_mode else False, ) elif args.plugin == "zero2_ep": plugin = MoeHybridParallelPlugin( @@ -171,7 +170,6 @@ def main(): MOE_MANAGER.setup( seed=42, parallel="EP", - use_kernel_optim=args.use_kernel if not test_mode else False, ) elif args.plugin == "hybrid": plugin = MoeHybridParallelPlugin( @@ -190,7 +188,6 @@ def main(): fixed_dp_size=args.dp_size, fixed_ep_size=args.ep_size, fixed_pp_size=args.pp_size, - use_kernel_optim=args.use_kernel, ) else: raise ValueError(f"Invalid plugin {args.plugin}") @@ -205,10 +202,28 @@ def main(): else: repo_name = "hpcaitech/openmoe-" + args.model_name config = LlamaConfig.from_pretrained(repo_name) - setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor) - setattr(config, "router_z_loss_factor", args.router_z_loss_factor) - setattr(config, "label_smoothing", args.label_smoothing) - setattr(config, "z_loss_factor", args.z_loss_factor) + moe_args = { + "num_experts": config.num_experts, + "moe_layer_interval": config.moe_layer_interval, + "router_topk": 2, + "router_capacity_factor_train": 1.25, + "router_capacity_factor_eval": 2.0, + "router_min_capacity": 4, + "router_noisy_policy": None, + "router_drop_tks": True, + "router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.01, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": False, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": False, + "enable_comm_overlap": False, + } + set_moe_args(config, moe_args) with skip_init(): model = OpenMoeForCausalLM(config) logger.info(f"Finish init model with config:\n{config}", ranks=[0]) diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index 934061ae4417..2e116de2db7d 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -14,13 +14,16 @@ class 
MoeModel(nn.Module): - def __init__(self, checkpoint: bool = False): + def __init__(self, checkpoint: bool = False, enable_load_balance: bool = False): class TestSubModule(CheckpointModule): def __init__(self): super().__init__(checkpoint) - self.moe = SparseMLP(num_experts=8, hidden_size=16, intermediate_size=32) + self.moe = SparseMLP(num_experts=8, + hidden_size=16, + intermediate_size=32, + enable_load_balance=enable_load_balance) self.proj = nn.Linear(16, 4) def _forward(self, x): diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 09af499185db..40aae12f016a 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -16,7 +16,6 @@ sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "examples/language/openmoe")) -# TODO: better way to import them OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy @@ -37,22 +36,27 @@ def get_config(): "head_dim": 4, "num_attention_heads": 4, "dropout_rate": 0.0, - "layer_norm_epsilon": 1e-06, "hidden_act": "swiglu", "num_experts": 16, - "topk": 2, "capacity_factor_train": 1.25, "capacity_factor_eval": 2.0, "min_capacity": 4, "noisy_policy": None, "drop_tks": True, - "expert_parallel": None, - "gated": True, "moe_layer_interval": 4, "router_aux_loss_factor": 0.1, "router_z_loss_factor": 0.1, "label_smoothing": 0.1, "z_loss_factor": 0.1, + "mlp_gated": True, + "label_smoothing": 0.001, + "z_loss_factor": 0.01, + "enable_load_balance": False, + "load_balance_tolerance": 0.1, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "enable_kernel": False, + "enable_comm_overlap": False, } for key, value in settings.items(): setattr(config, key, value) diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py index 
b4eea04bc85a..5126c61ae92f 100644 --- a/tests/test_moe/test_moe_load_balance.py +++ b/tests/test_moe/test_moe_load_balance.py @@ -47,12 +47,8 @@ def run_zero_optim_test(local_rank, world_size, stage=1): MOE_MANAGER.setup( seed=42, parallel="EP", - enable_load_balance=True, - tolerance=0.1, - beam_width=8, - group_swap_factor=0.4, ) - zero_model = MoeModel(checkpoint=True) + zero_model = MoeModel(checkpoint=True, enable_load_balance=True) zero_optimizer = torch.optim.Adam(zero_model.parameters()) plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True) booster = Booster(plugin=plugin) @@ -118,12 +114,8 @@ def run_hybrid_zero_optim_test(local_rank, world_size, stage=1): max_ep_size=2, use_ep_inside=False, parallel="EP", - enable_load_balance=True, - tolerance=0.1, - beam_width=8, - group_swap_factor=0.4, ) - zero_model = MoeModel(checkpoint=True) + zero_model = MoeModel(checkpoint=True, enable_load_balance=True) extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group) ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size