From 8306a33703cb7100b0aca7db3ccc4ed8ee2090e1 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Tue, 3 Oct 2023 11:59:12 +0800
Subject: [PATCH 1/6] overlap comm

---
 colossalai/moe/_operation.py |  64 ++++++++++++---
 colossalai/moe/experts.py    |  72 +++++++++++------
 colossalai/moe/layers.py     | 146 +++++++++++++++++++++++++++--------
 3 files changed, 217 insertions(+), 65 deletions(-)

diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index f483f2d85989..1300dc6dafb1 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -8,6 +8,8 @@ from colossalai.moe.manager import MOE_MANAGER

 MOE_KERNEL = None
+WROLD_HANDLE_ALLGATHER = None
+WROLD_HANDLE_REDUCESCATTER = None


 def load_moe():
@@ -20,9 +22,15 @@ def load_moe():
 class AllGather(torch.autograd.Function):

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
+    def forward(
+        ctx: Any,
+        inputs: Tensor,
+        group: Optional[ProcessGroup] = None,
+        overlap: bool = False,
+    ) -> Tensor:
         if ctx is not None:
             ctx.comm_grp = group
+            ctx.overlap = overlap

         comm_size = dist.get_world_size(group)
         if comm_size == 1:
@@ -31,20 +39,41 @@ def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         buffer_shape = (comm_size,) + inputs.shape
         outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
-        dist.all_gather(buffer_list, inputs, group=group)
-        return outputs
+        if not overlap:
+            dist.all_gather(buffer_list, inputs, group=group)
+            return outputs, None
+        else:
+            handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
+            if ctx is None and overlap:
+                global WROLD_HANDLE_ALLGATHER
+                WROLD_HANDLE_ALLGATHER = handle
+            return outputs, handle

     @staticmethod
-    def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
+        global WROLD_HANDLE_REDUCESCATTER
+        if WROLD_HANDLE_REDUCESCATTER is not None:
+            WROLD_HANDLE_REDUCESCATTER.wait()
+            WROLD_HANDLE_REDUCESCATTER = None
+        return (
+            ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0],
+            None,
+            None,
+        )


 class ReduceScatter(torch.autograd.Function):

     @staticmethod
-    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
+    def forward(
+        ctx: Any,
+        inputs: Tensor,
+        group: Optional[ProcessGroup] = None,
+        overlap: bool = False,
+    ) -> Tensor:
         if ctx is not None:
             ctx.comm_grp = group
+            ctx.overlap = overlap

         comm_size = dist.get_world_size(group)
         if comm_size == 1:
@@ -56,12 +85,27 @@ def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
         output_shape = inputs.shape[1:]
         outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
         buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
-        dist.reduce_scatter(outputs, buffer_list, group=group)
-        return outputs
+        if not overlap:
+            dist.reduce_scatter(outputs, buffer_list, group=group)
+            return outputs, None
+        else:
+            handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
+            if ctx is None and overlap:
+                global WROLD_HANDLE_REDUCESCATTER
+                WROLD_HANDLE_REDUCESCATTER = handle
+            return outputs, handle

     @staticmethod
-    def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
+        global WROLD_HANDLE_ALLGATHER
+        if WROLD_HANDLE_ALLGATHER is not None:
+            WROLD_HANDLE_ALLGATHER.wait()
+            WROLD_HANDLE_ALLGATHER = None
+        return (
+            AllGather.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0],
+            None,
+            None,
+        )


 class AllToAll(torch.autograd.Function):
diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py
index 4535d8ab9a85..30476d20a5f7 100644
--- a/colossalai/moe/experts.py
+++ b/colossalai/moe/experts.py
@@ -1,6 +1,6 @@
 import math
 from contextlib import nullcontext
-from typing import Callable, Optional
+from typing import Callable, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -52,8 +52,9 @@ def __init__(
             num_experts, use_tp=True if expert_parallel == "TP" else False)
         # get settings for different parallel
         if expert_parallel == "TP":
-            assert intermediate_size % MOE_MANAGER.max_ep_size == 0, \
-                "intermediate_size should be divide by maximum expert parallel size"
+            assert (
+                intermediate_size %
+                MOE_MANAGER.max_ep_size == 0), "intermediate_size should be divisible by maximum expert parallel size"
             intermediate_size = intermediate_size // MOE_MANAGER.max_ep_size
             num_experts = self.num_total_experts
         else:
@@ -91,7 +92,7 @@ def __init__(
         for param in self.parameters():
             set_moe_tensor_info(param, self.moe_info)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor:
         """
         Args:
             x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
@@ -110,14 +111,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         if self.gated:
             if HAS_TRITON and self.act_name == "swiglu":
-                x = LlamaActCombine.apply(torch.bmm(x, self.wi_gate), torch.bmm(x, self.wi_up))
+                x = LlamaActCombine.apply(
+                    torch.bmm(x, self.wi_gate[param_slice]),
+                    torch.bmm(x, self.wi_up[param_slice]),
+                )
             else:
-                x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up)
+                x = self.act(torch.bmm(x, self.wi_gate[param_slice])) * torch.bmm(x, self.wi_up[param_slice])
         else:
-            x = torch.bmm(x, self.wi)
+            x = torch.bmm(x, self.wi[param_slice])
             x = self.act(x)
         x = self.drop(x)
-        x = torch.bmm(x, self.wo)
+        x = torch.bmm(x, self.wo[param_slice])

         x = x.reshape(inshape)
         x = x.transpose(0, 1).contiguous()
@@ -130,14 +134,24 @@ class EPMLPExperts(BaseMLPExperts):
     Use expert parallelism to split each expert evenly, which can deploy experts in
     """

-    def __init__(self,
-                 num_experts: int,
-                 hidden_size: int,
-                 intermediate_size: int,
-                 activation=None,
-                 drop_rate: float = 0,
-                 gated: bool = False):
-        super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate, gated)
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        activation=None,
+        drop_rate: float = 0,
+        gated: bool = False,
+    ):
+        super().__init__(
+            num_experts,
+            hidden_size,
+            intermediate_size,
+            "EP",
+            activation,
+            drop_rate,
+            gated,
+        )


 class TPMLPExperts(BaseMLPExperts):
@@ -146,14 +160,24 @@ class TPMLPExperts(BaseMLPExperts):
     maximum expert parallel size can't be divide by the number of experts.
""" - def __init__(self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - activation: str = None, - drop_rate: float = 0, - gated: bool = False): - super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate, gated) + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + activation: str = None, + drop_rate: float = 0, + gated: bool = False, + ): + super().__init__( + num_experts, + hidden_size, + intermediate_size, + "TP", + activation, + drop_rate, + gated, + ) def get_expert_class(name: str) -> BaseMLPExperts: diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index a78bfe0a3d74..c2cf627aceae 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -42,42 +42,52 @@ class SparseMLP(nn.Module): https://arxiv.org/abs/2201.05596 """ - def __init__(self, - num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - expert_parallel: str = "EP", - hidden_size: int = 2048, - intermediate_size: int = 2048, - activation: str = None, - gated: bool = False): + def __init__( + self, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + expert_parallel: str = "EP", + hidden_size: int = 2048, + intermediate_size: int = 2048, + activation: str = None, + gated: bool = False, + ): super().__init__() self.hidden_size = hidden_size self.num_experts = num_experts self.use_kernel = MOE_MANAGER.use_kernel_optim self.expert_parallel = expert_parallel - assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}" + assert expert_parallel in [ + "EP", + "TP", + None, + ], f"Unsupported expert parallel type {expert_parallel}" # moe router noisy_func = get_noise_generator(noisy_policy, num_experts) router_cls = get_router_cls(top_k) - self.router: MoeRouter = router_cls(capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) + self.router: MoeRouter = router_cls( + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks, + ) # moe experts expert_cls = get_expert_class(expert_parallel) - self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - activation=activation, - gated=gated) + self.experts: BaseMLPExperts = expert_cls( + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation=activation, + gated=gated, + ) if expert_parallel is not None: self.ep_group = get_ep_group(self.experts) self.ep_size = get_ep_size(self.experts) @@ -88,9 +98,7 @@ def __init__(self, self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size)) nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size)) - def forward(self, - inputs: torch.Tensor) \ - -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size) @@ -146,6 +154,15 @@ def _local_process(self, expert_in: 
         return expert_out

     def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
+        """
+        Expert Parallel
+
+        Args:
+            dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
+
+        Returns:
+            torch.Tensor: (num_experts, capacity, hidden_size)
+        """
         expert_input = AllToAll.apply(dispatch_data, self.ep_group)
         input_shape = expert_input.shape
         expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
@@ -155,7 +172,74 @@ def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
         return expert_output

     def _tp_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
-        expert_in = AllGather.apply(dispatch_data, self.ep_group)
-        expert_out = self.experts(expert_in)
-        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
-        return expert_out
+        """
+        TP with overlap.
+
+        origin:
+        | C |
+        | A | | R |
+
+        overlap:
+        | C1 || C2 || C3 || C4 |
+        | A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 |
+
+        C is computation, A is all gather, R is reduce scatter.
+
+        Args:
+            dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
+
+        Returns:
+            torch.Tensor: (num_experts, capacity, hidden_size)
+        """
+        chunk_num = 4
+        chunk_size = dispatch_data.shape[0] // chunk_num
+        out = torch.empty_like(dispatch_data)
+        in_data = None
+        in_handle = None
+        out_data = None
+        out_handle = None
+
+        # backward compatibility for async op
+        torch.cuda.synchronize()
+
+        def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]:
+            return (slice(idx * gap, (idx + 1) * gap),)
+
+        for i in range(chunk_num):
+            cur_chunk_slice = get_chunk_slice(i, chunk_size)
+
+            # if first, all gather
+            if i == 0:
+                d = dispatch_data[cur_chunk_slice].contiguous()
+                expert_in, _ = AllGather.apply(d, self.ep_group)
+            else:
+                expert_in = in_data
+
+            # async communication while compute
+            if i != 0:
+                # reduce scatter last out
+                out_data, out_handle = ReduceScatter.apply(out_data, self.ep_group, True)
+            if i != chunk_num - 1:
+                # all gather next in
+                next_d = dispatch_data[get_chunk_slice(i + 1, chunk_size)].contiguous()
+                in_data, in_handle = AllGather.apply(next_d, self.ep_group, True)
+
+            # compute
+            expert_out = self.experts(expert_in, cur_chunk_slice)
+
+            # sync handle
+            if i != 0:
+                out_handle.wait()
+                out[get_chunk_slice(i - 1, chunk_size)] = out_data
+            if i != chunk_num - 1:
+                in_handle.wait()
+            out_data = expert_out
+
+            # store out for last loop
+            if i == chunk_num - 1:
+                out_data, _ = ReduceScatter.apply(out_data, self.ep_group)
+                out[cur_chunk_slice] = out_data
+
+        # sync for async op
+        torch.cuda.synchronize()
+        return out

From a1dd69375f7d2b98e5dd10a9bdb98e79c0f7a8c2 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Tue, 3 Oct 2023 13:13:17 +0800
Subject: [PATCH 2/6] fix typo

---
 colossalai/moe/_operation.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index 1300dc6dafb1..7594ae6c0a19 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -8,8 +8,8 @@ from colossalai.moe.manager import MOE_MANAGER

 MOE_KERNEL = None
-WROLD_HANDLE_ALLGATHER = None
-WROLD_HANDLE_REDUCESCATTER = None
+WORLD_HANDLE_ALLGATHER = None
+WORLD_HANDLE_REDUCESCATTER = None


 def load_moe():
@@ -45,16 +45,16 @@ def forward(
         else:
             handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
             if ctx is None and overlap:
-                global WROLD_HANDLE_ALLGATHER
-                WROLD_HANDLE_ALLGATHER = handle
+                global WORLD_HANDLE_ALLGATHER
+                WORLD_HANDLE_ALLGATHER = handle
             return outputs, handle

     @staticmethod
     def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
-        global WROLD_HANDLE_REDUCESCATTER
-        if WROLD_HANDLE_REDUCESCATTER is not None:
-            WROLD_HANDLE_REDUCESCATTER.wait()
-            WROLD_HANDLE_REDUCESCATTER = None
+        global WORLD_HANDLE_REDUCESCATTER
+        if WORLD_HANDLE_REDUCESCATTER is not None:
+            WORLD_HANDLE_REDUCESCATTER.wait()
+            WORLD_HANDLE_REDUCESCATTER = None
         return (
             ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0],
             None,
@@ -91,16 +91,16 @@ def forward(
         else:
             handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
             if ctx is None and overlap:
-                global WROLD_HANDLE_REDUCESCATTER
-                WROLD_HANDLE_REDUCESCATTER = handle
+                global WORLD_HANDLE_REDUCESCATTER
+                WORLD_HANDLE_REDUCESCATTER = handle
             return outputs, handle

     @staticmethod
     def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
-        global WROLD_HANDLE_ALLGATHER
-        if WROLD_HANDLE_ALLGATHER is not None:
-            WROLD_HANDLE_ALLGATHER.wait()
-            WROLD_HANDLE_ALLGATHER = None
+        global WORLD_HANDLE_ALLGATHER
+        if WORLD_HANDLE_ALLGATHER is not None:
+            WORLD_HANDLE_ALLGATHER.wait()
+            WORLD_HANDLE_ALLGATHER = None
         return (
             AllGather.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0],
             None,
             None,
         )

From a2be80faa18663b5a8c297361bf46c108aa1492a Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Tue, 3 Oct 2023 15:09:07 +0800
Subject: [PATCH 3/6] update bench script

---
 .../openmoe/benchmark/benchmark_cai.py | 36 ++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index d7dbd58ed0ca..8d328b6e23a3 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -1,7 +1,10 @@
+import os
+
 import datasets
 import torch
 import torch.distributed as dist
 import transformers
+from huggingface_hub import snapshot_download
 from model.modeling_openmoe import OpenMoeForCausalLM
 from model.openmoe_policy import OpenMoeForCausalLMPolicy
 from torch.utils.data import Dataset
@@ -18,6 +21,7 @@
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device


@@ -25,6 +29,19 @@ def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}


+def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
+    ckpt_path = snapshot_download(repo_name)
+    # single ckpt
+    if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
+        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
+    # shard ckpt
+    elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
+        ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
+    else:
+        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
+    booster.load_model(model, ckpt_path)
+
+
 class RandomDataset(Dataset):

     def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
@@ -135,6 +152,21 @@ def main():
             parallel="EP",
             use_kernel_optim=args.use_kernel,
         )
+    elif args.plugin == "zero2_tp":
+        dp_size = dist.get_world_size()
+        plugin = MoeHybridParallelPlugin(
+            tp_size=1,
+            pp_size=1,
+            zero_stage=2,
+            custom_policy=OpenMoeForCausalLMPolicy(),
+            enable_fused_normalization=args.use_kernel,
+            enable_jit_fused=args.use_kernel,
+        )
+        MOE_MANAGER.setup(
+            seed=42,
+            parallel="TP",
+            use_kernel_optim=args.use_kernel,
+        )
     elif args.plugin == "hybrid":
         dp_size = dist.get_world_size() // args.pp_size
         plugin = MoeHybridParallelPlugin(
@@ -166,7 +198,8 @@ def main():
     setattr(config, "router_z_loss_factor", 0.1)
     setattr(config, "label_smoothing", 0.1)
     setattr(config, "z_loss_factor", 0.1)
-    model = OpenMoeForCausalLM(config)
+    with skip_init():
+        model = OpenMoeForCausalLM(config)
     logger.info(f"Finish init model with config:\n{config}", ranks=[0])

     # Enable gradient checkpointing
@@ -193,6 +226,7 @@ def main():
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
     model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
+    load_ckpt(repo_name, model, booster)
     use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1)
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
     logger.info(f"Finish init booster", ranks=[0])

From 28193c4683ea21dde962597887af69fad7d3f81a Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Tue, 3 Oct 2023 15:31:16 +0800
Subject: [PATCH 4/6] add option

---
 examples/language/openmoe/benchmark/benchmark_cai.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index 8d328b6e23a3..ee14ace40c96 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -89,7 +89,7 @@ def parse_args():
         type=str,
         default="hybrid",
         help="parallel plugin",
-        choices=["zero2", "zero2_ep", "hybrid"],
+        choices=["zero2", "zero2_ep", "hybrid", "zero2_tp"],
     )
     # hybrid plugin
     parser.add_argument("--pp_size", type=int, default=2, help="pp size")

From 702fb70a834569e296c38066d1ece729ddca3b11 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Tue, 3 Oct 2023 15:32:46 +0800
Subject: [PATCH 5/6] update script

---
 .../openmoe/benchmark/benchmark_cai.sh | 41 ++++++++++++------
 1 file changed, 26 insertions(+), 15 deletions(-)

diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh
index 620bd4901ccd..5db65a216461 100755
--- a/examples/language/openmoe/benchmark/benchmark_cai.sh
+++ b/examples/language/openmoe/benchmark/benchmark_cai.sh
@@ -5,8 +5,8 @@ set -xue
 NUM_GPU=8
 MODEL="8b"
 SEQ_LENGTH=2048
-WARMUP=5
-ACTIVE=5
+WARMUP=8
+ACTIVE=4

 # HACK: make model importable
 example_dir=$(dirname $(realpath $(dirname $0)))
@@ -16,40 +16,51 @@ else
     export PYTHONPATH=$example_dir:$PYTHONPATH
 fi

-# hybrid
+# zero2
 torchrun --standalone --nproc_per_node $NUM_GPU \
     $example_dir/benchmark/benchmark_cai.py \
     --model_name $MODEL \
-    --batch_size 512 \
+    --batch_size 4 \
     --seq_length $SEQ_LENGTH \
     --warmup $WARMUP \
     --active $ACTIVE \
-    --use_kernel \
-    --plugin hybrid \
-    --pp_size 2 \
-    --dp_size 1 \
-    --ep_size 4 \
-    --zero_stage 1 \
-    --microbatch_size 32
+    --plugin zero2 \
+    --use_kernel

-# zero2
+# zero2_tp
 torchrun --standalone --nproc_per_node $NUM_GPU \
     $example_dir/benchmark/benchmark_cai.py \
     --model_name $MODEL \
-    --batch_size 8 \
+    --batch_size 12 \
     --seq_length $SEQ_LENGTH \
     --warmup $WARMUP \
     --active $ACTIVE \
-    --plugin zero2 \
+    --plugin zero2_tp \
     --use_kernel

 # zero2_ep
 torchrun --standalone --nproc_per_node $NUM_GPU \
     $example_dir/benchmark/benchmark_cai.py \
     --model_name $MODEL \
-    --batch_size 16 \
+    --batch_size 12 \
     --seq_length $SEQ_LENGTH \
     --warmup $WARMUP \
     --active $ACTIVE \
     --plugin zero2_ep \
     --use_kernel
+
+# hybrid
+torchrun --standalone --nproc_per_node $NUM_GPU \
+    $example_dir/benchmark/benchmark_cai.py \
+    --model_name $MODEL \
+    --batch_size 512 \
+    --seq_length $SEQ_LENGTH \
+    --warmup $WARMUP \
+    --active $ACTIVE \
+    --use_kernel \
+    --plugin hybrid \
+    --pp_size 2 \
+    --dp_size 1 \
+    --ep_size 4 \
+    --zero_stage 1 \
+    --microbatch_size 32

From a46c7d039b4d88577267f46ea4f1a478b5923441 Mon Sep 17 00:00:00 2001
From: oahzxl
Date: Tue, 3 Oct 2023 18:15:49 +0800
Subject: [PATCH 6/6] update bench

---
 colossalai/moe/experts.py                  |  8 +--
 colossalai/moe/utils.py                    | 56 ++++++++++---------
 .../openmoe/benchmark/benchmark_cai.py     | 21 ++-----
 3 files changed, 39 insertions(+), 46 deletions(-)

diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py
index 30476d20a5f7..e05ea59b3d28 100644
--- a/colossalai/moe/experts.py
+++ b/colossalai/moe/experts.py
@@ -78,11 +78,11 @@ def __init__(
         seed_ctx = nullcontext()
         with seed_ctx:
             if gated:
-                nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
-                nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
+                torch.nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
+                torch.nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
             else:
-                nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
-            nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))
+                torch.nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
+            torch.nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))

         self.act_name = activation
         self.act = get_activation(activation)
diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py
index 58c1665a4d63..e3bc6d3cac9a 100644
--- a/colossalai/moe/utils.py
+++ b/colossalai/moe/utils.py
@@ -28,9 +28,10 @@ class NormalNoiseGenerator:
     """

     def __init__(self, num_experts: int):
-        self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()),
-                                                        scale=torch.tensor(1.0 / num_experts**2,
-                                                                           device=get_current_device())).rsample
+        self.normal = torch.distributions.normal.Normal(
+            loc=torch.tensor(0.0, device=get_current_device()),
+            scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
+        ).rsample

     def __call__(self, inputs: torch.Tensor):
         noisy = self.normal(inputs.shape)
@@ -49,9 +50,10 @@ class UniformNoiseGenerator:
     """

     def __init__(self, eps: float = 1e-2):
-        self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()),
-                                                           high=torch.tensor(1.0 + eps,
-                                                                             device=get_current_device())).rsample
+        self.uniform = torch.distributions.uniform.Uniform(
+            low=torch.tensor(1.0 - eps, device=get_current_device()),
+            high=torch.tensor(1.0 + eps, device=get_current_device()),
+        ).rsample

     def __call__(self, inputs: torch.Tensor):
         noisy = self.uniform(inputs.shape)
@@ -65,9 +67,9 @@ def autocast_softmax(logit: torch.Tensor, dim: int):
 def get_noise_generator(noise_type: str, num_experts: int) -> Callable:
     if noise_type is None:
         return None
-    elif noise_type == 'Jitter':
+    elif noise_type == "Jitter":
         noisy_func = UniformNoiseGenerator()
-    elif noise_type == 'Gaussian':
+    elif noise_type == "Gaussian":
         noisy_func = NormalNoiseGenerator(num_experts)
     else:
         raise NotImplementedError("Unsupported input noisy policy")
@@ -75,11 +77,11 @@ def get_activation(act: str) -> Callable:
-    if act is None or act == 'relu':
+    if act is None or act == "relu":
         return torch.nn.ReLU()
-    elif act == 'gelu':
+    elif act == "gelu":
         return torch.nn.GELU()
-    elif act == 'swiglu':
+    elif act == "swiglu":
         return SwiGLU
     else:
         raise NotImplementedError("Unsupported activation function")
@@ -103,24 +105,28 @@ def skip_init():
     skip param random init
     """

-    def _skip_init(x, *args, **kwargs):
-        return x
+    def _skip_init(*args, **kwargs):
+        pass

-    # __enter__
-    fn_saved = []
-    init_fn_list = [
-        torch.nn.init.constant_, torch.nn.init.uniform_, torch.nn.init.normal_, torch.nn.init.xavier_uniform_,
-        torch.nn.init.xavier_normal_, torch.nn.init.kaiming_uniform_, torch.nn.init.kaiming_normal_
-    ]
-    for fn in init_fn_list:
-        fn_saved.append(fn)
-        fn = _skip_init
+    init_func = {
+        "constant_": torch.nn.init.constant_,
+        "uniform_": torch.nn.init.uniform_,
+        "normal_": torch.nn.init.normal_,
+        "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
+        "kaiming_normal_": torch.nn.init.kaiming_normal_,
+        "xavier_normal_": torch.nn.init.xavier_normal_,
+        "xavier_uniform_": torch.nn.init.xavier_uniform_,
+        "trunc_normal_": torch.nn.init.trunc_normal_,
+    }
+
+    for method_name, original_init in init_func.items():
+        setattr(torch.nn.init, method_name, _skip_init)

     yield

-    # __exit__
-    for fn, fn_saved in zip(init_fn_list, fn_saved):
-        fn = fn_saved
+    for method_name, original_init in init_func.items():
+        setattr(torch.nn.init, method_name, original_init)
+
     return
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index ee14ace40c96..5ff0843caaea 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -1,9 +1,7 @@
 import os

-import datasets
 import torch
 import torch.distributed as dist
-import transformers
 from huggingface_hub import snapshot_download
 from model.modeling_openmoe import OpenMoeForCausalLM
 from model.openmoe_policy import OpenMoeForCausalLMPolicy
 from torch.utils.data import Dataset
@@ -19,7 +17,6 @@
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
-from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device
@@ -117,16 +114,6 @@ def main():
     colossalai.launch_from_torch(config={}, seed=args.seed)
     coordinator = DistCoordinator()

-    # Manage loggers
-    disable_existing_loggers()
-    logger = get_dist_logger()
-    if coordinator.is_master():
-        datasets.utils.logging.set_verbosity_warning()
-        transformers.utils.logging.set_verbosity_info()
-    else:
-        datasets.utils.logging.set_verbosity_error()
-        transformers.utils.logging.set_verbosity_error()
-
     # Set plugin
     booster_kwargs = {}
     if args.plugin == "zero2":
@@ -189,7 +176,7 @@ def main():
     else:
         raise ValueError(f"Invalid plugin {args.plugin}")
-    logger.info(f"Set plugin as {plugin}", ranks=[0])
+    coordinator.print_on_master(f"Set plugin as {plugin}")

     # Build OpenMoe model
     repo_name = "hpcaitech/openmoe-" + args.model_name
@@ -200,7 +187,7 @@ def main():
     setattr(config, "z_loss_factor", 0.1)
     with skip_init():
         model = OpenMoeForCausalLM(config)
-    logger.info(f"Finish init model with config:\n{config}", ranks=[0])
+    coordinator.print_on_master(f"Finish init model with config:\n{config}")

     # Enable gradient checkpointing
     model.gradient_checkpointing_enable()
@@ -229,10 +216,10 @@ def main():
     load_ckpt(repo_name, model, booster)
     use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1)
     is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
-    logger.info(f"Finish init booster", ranks=[0])
+    coordinator.print_on_master(f"Finish init booster")

     # Start finetuning
-    logger.info(f"Start finetuning", ranks=[0])
+    coordinator.print_on_master(f"Start finetuning")
     model.train()
     train_dataloader_iter = iter(dataloader)
     total_len = len(train_dataloader_iter) - 1
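Editor's note: the heart of this series is the chunked compute/communication pipeline that patch 1 adds to `_tp_process` (see the C/A/R timeline in its docstring). Below is a minimal, standalone sketch of that pattern, not ColossalAI API: `pipelined_tp_forward` and the `compute` callback are hypothetical names, the patch's autograd.Function wrappers and global-handle bookkeeping are omitted, and an initialized default process group plus a chunk count that divides data.shape[0] are assumed.

    import torch
    import torch.distributed as dist

    def pipelined_tp_forward(data: torch.Tensor, compute, n_chunks: int = 4) -> torch.Tensor:
        """Overlap all-gather / compute / reduce-scatter over n_chunks slices of `data`."""
        ws = dist.get_world_size()
        step = data.shape[0] // n_chunks
        out = torch.empty_like(data)

        def all_gather_async(src):
            # flat (ws * step, ...) buffer; each chunk matches src's shape exactly
            buf = torch.empty(ws * src.shape[0], *src.shape[1:], dtype=src.dtype, device=src.device)
            return buf, dist.all_gather(list(buf.chunk(ws, dim=0)), src, async_op=True)

        def reduce_scatter_async(src):
            # sum partial results across ranks; each rank keeps one (step, ...) slice
            dst = torch.empty(src.shape[0] // ws, *src.shape[1:], dtype=src.dtype, device=src.device)
            return dst, dist.reduce_scatter(dst, list(src.chunk(ws, dim=0)), async_op=True)

        cur_in, cur_h = all_gather_async(data[0:step].contiguous())  # prefetch A1
        prev_out, prev_h = None, None

        for i in range(n_chunks):
            if i + 1 < n_chunks:  # A(i+1) is in flight while C(i) runs
                nxt_in, nxt_h = all_gather_async(data[(i + 1) * step:(i + 2) * step].contiguous())
            cur_h.wait()
            result = compute(cur_in)  # expert-MLP stand-in, (ws * step, ...) -> (ws * step, ...)

            if prev_h is not None:  # retire R(i-1), which overlapped C(i)
                prev_h.wait()
                out[(i - 1) * step:i * step] = prev_out
            prev_out, prev_h = reduce_scatter_async(result)  # R(i) overlaps C(i+1)

            if i + 1 < n_chunks:
                cur_in, cur_h = nxt_in, nxt_h

        prev_h.wait()
        out[-step:] = prev_out
        return out

Each iteration issues the all-gather for chunk i+1 and the reduce-scatter for chunk i-1 asynchronously, so both collectives progress while `compute` handles chunk i — the same A/C/R overlap the `_tp_process` docstring diagram describes, where `compute` stands in for `self.experts(expert_in, cur_chunk_slice)`.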