From 3a9a227db87bc3cdd3051b97b2fd8956b382ebf3 Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Mon, 16 Oct 2023 18:15:08 +0800
Subject: [PATCH] update bench

---
 colossalai/moe/_operation.py            | 40 +++++++++-----
 .../openmoe/benchmark/benchmark_cai.py  | 54 ++++++++++++++++---
 2 files changed, 75 insertions(+), 19 deletions(-)

diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index 7594ae6c0a19..a932b96597b6 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -3,6 +3,7 @@
 import torch
 import torch.distributed as dist
 from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
 
 from colossalai.moe.manager import MOE_MANAGER
@@ -133,32 +134,43 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
 
 class MoeDispatch(torch.autograd.Function):
 
     @staticmethod
+    @custom_fwd
     def forward(ctx, tokens, mask, dest_idx, ec):
         s = tokens.size(0)
         h = tokens.size(1)
+        dtype = tokens.dtype
 
         if MOE_KERNEL is None:
             load_moe()
-
+        if tokens.dtype != torch.float32:
+            tokens = tokens.to(torch.float32)
         expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
-
+        if expert_input.dtype != dtype:
+            expert_input = expert_input.to(dtype)
         ctx.save_for_backward(mask, dest_idx)
         ctx.s = s
         ctx.h = h
         ctx.ec = ec
+        ctx.dtype = dtype
 
         return expert_input
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, output_grad):
         mask, dest_idx = ctx.saved_tensors
+        if output_grad.dtype != torch.float32:
+            output_grad = output_grad.to(torch.float32)
         d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
+        if d_tokens.dtype != ctx.dtype:
+            d_tokens = d_tokens.to(ctx.dtype)
         return d_tokens, None, None, None
 
 
 class MoeCombine(torch.autograd.Function):
 
     @staticmethod
+    @custom_fwd
     def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
         assert logits.dtype == torch.float32
@@ -166,32 +178,36 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
         s = logits.size(0)
         e = logits.size(1)
         c = ec // e
         h = expert_tokens.size(-1)
+        dtype = expert_tokens.dtype
 
-        fp16_flag = expert_tokens.dtype == torch.float16
-        cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
+        if expert_tokens.dtype != torch.float32:
+            expert_tokens = expert_tokens.to(torch.float32)
         if MOE_KERNEL is None:
             load_moe()
-        ctokens = MOE_KERNEL.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
-        output = ctokens.to(torch.float16) if fp16_flag else ctokens
+        output = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx)
+        if output.dtype != dtype:
+            output = output.to(dtype)
 
         ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
         ctx.s = s
         ctx.e = e
         ctx.c = c
         ctx.h = h
-        ctx.fp16_flag = fp16_flag
+        ctx.dtype = dtype
 
         return output
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, tokens_grad):
         expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
+        if tokens_grad.dtype != torch.float32:
+            tokens_grad = tokens_grad.to(torch.float32)
 
-        cb_grad = (tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad)
-        cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
-        d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask,
-                                                         dest_idx)
-        d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
+        d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits,
+                                                         mask, dest_idx)
+        if d_expert.dtype != ctx.dtype:
+            d_expert = d_expert.to(ctx.dtype)
 
         return d_expert, d_logits, None, None, None
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
index 830ff9df0ec6..e1acba5c88b0 100644
--- a/examples/language/openmoe/benchmark/benchmark_cai.py
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -1,3 +1,4 @@
+import json
 import os
 
 import torch
@@ -7,7 +8,7 @@
 from model.openmoe_policy import OpenMoeForCausalLMPolicy
 from torch.utils.data import Dataset
 from tqdm import tqdm
-from transformers import Adafactor
+from transformers import T5Tokenizer
 from transformers.models.llama import LlamaConfig
 from utils import PerformanceEvaluator, get_model_numel
 
@@ -17,6 +18,7 @@
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device
@@ -41,11 +43,36 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
 
 class RandomDataset(Dataset):
 
-    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384):
+    def __init__(self,
+                 num_samples: int = 1000,
+                 max_length: int = 2048,
+                 vocab_size: int = 256384,
+                 tokenizer: T5Tokenizer = None):
         self.num_samples = num_samples
         self.max_length = max_length
-        self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
-        self.attention_mask = torch.ones_like(self.input_ids)
+        if os.path.exists("./mock_data.json"):
+            self.input_ids = []
+            self.attention_mask = []
+            with open("./mock_data.json", 'r') as f:
+                data = json.load(f)
+            for v in data.values():
+                d = v["text"]
+                encode = tokenizer("<pad>" + d,
+                                   return_tensors="pt",
+                                   add_special_tokens=False,
+                                   max_length=max_length,
+                                   truncation=True,
+                                   padding="max_length")
+                self.input_ids.append(encode["input_ids"])
+                self.attention_mask.append(encode["attention_mask"])
+            self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
+            self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
+            repeat_times = num_samples // self.input_ids.shape[0] + 1
+            self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
+            self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
+        else:
+            self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+            self.attention_mask = torch.ones_like(self.input_ids)
 
     def __len__(self):
         return self.num_samples
@@ -103,6 +130,8 @@ def parse_args():
     # bench
     parser.add_argument("--warmup", type=int, default=20)
     parser.add_argument("--active", type=int, default=20)
+    # load balance
+    parser.add_argument("--load_balance", action="store_true")
 
     args = parser.parse_args()
     return args
@@ -116,8 +145,14 @@ def main():
 
     # Set plugin
     booster_kwargs = {}
-    hybrid_dict = {"tp_size": 1, "custom_policy": OpenMoeForCausalLMPolicy(), "enable_fused_normalization": args.use_kernel, "enable_jit_fused": args.use_kernel}
-    mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel}
+    hybrid_dict = {
+        "tp_size": 1,
+        "custom_policy": OpenMoeForCausalLMPolicy(),
+        "enable_fused_normalization": args.use_kernel,
+        "enable_jit_fused": args.use_kernel,
+        "precision": "bf16"
+    }
+    mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel, "enable_load_balance": args.load_balance}
     if args.plugin == "zero":
         dp_size = dist.get_world_size()
         plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
@@ -203,14 +238,16 @@ def main():
         model.gradient_checkpointing_enable()
 
     # Prepare tokenizer and dataloader
+    tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
     dataset = RandomDataset(
         num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
         max_length=args.seq_length,
+        tokenizer=tokenizer,
    )
     dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)
 
     # Set optimizer
-    optimizer = Adafactor(model.parameters(), weight_decay=0.01)
+    optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
 
     model_numel = get_model_numel(model)
     performance_evaluator = PerformanceEvaluator(
@@ -264,6 +301,9 @@ def main():
         optimizer.step()
         optimizer.zero_grad()
         performance_evaluator.on_step_end(exmaple_data["input_ids"])
+        if (step == args.warmup // 2) and args.load_balance:
+            apply_load_balance(model, optimizer)
+            coordinator.print_on_master(f"Apply load balance")
 
     performance_evaluator.on_fit_end()
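
Note on the _operation.py change: the fused MOE_KERNEL ops only accept float32, so the patch casts inputs and incoming gradients up to fp32 around each kernel call and casts results back to the caller's dtype, while @custom_fwd/@custom_bwd keep the autocast state consistent between forward and backward. Below is a minimal standalone sketch of that pattern, not the library code; fp32_only_op is a hypothetical stand-in for the fused kernel:

    import torch
    from torch.cuda.amp import custom_bwd, custom_fwd


    def fp32_only_op(x: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for a fused kernel that requires float32 input.
        assert x.dtype == torch.float32
        return x * 2


    class Fp32Kernel(torch.autograd.Function):

        @staticmethod
        @custom_fwd  # records the autocast state; with cast_inputs unset, inputs keep their real dtype
        def forward(ctx, x):
            dtype = x.dtype  # remember the caller's dtype (e.g. bf16 under autocast)
            if x.dtype != torch.float32:
                x = x.to(torch.float32)
            out = fp32_only_op(x)
            ctx.dtype = dtype
            # cast back so downstream layers keep computing in the caller's dtype
            return out.to(dtype) if out.dtype != dtype else out

        @staticmethod
        @custom_bwd  # replays the recorded autocast state during backward
        def backward(ctx, grad_out):
            if grad_out.dtype != torch.float32:
                grad_out = grad_out.to(torch.float32)
            grad_in = fp32_only_op(grad_out)  # d/dx of 2x is 2, so the same op applies
            return grad_in.to(ctx.dtype) if grad_in.dtype != ctx.dtype else grad_in

Under torch.autocast with dtype=torch.bfloat16, Fp32Kernel.apply receives bf16 tensors, runs the kernel in fp32, and returns bf16 to both the forward graph and the gradient stream, matching the bf16 precision the benchmark now requests from the plugin.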
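Note on the dataset change: when ./mock_data.json exists next to the script, RandomDataset tokenizes real text instead of sampling random ids. The patch does not ship that file, but from the loader's access pattern (data.values(), then v["text"]) it must be a JSON object whose values each carry a "text" string; the keys are arbitrary. A hypothetical generator for such a file:

    import json

    # Keys are arbitrary; RandomDataset only reads the "text" field of each value.
    samples = {str(i): {"text": f"benchmark sample {i} " * 64} for i in range(8)}
    with open("./mock_data.json", "w") as f:
        json.dump(samples, f)

A handful of texts suffices, since the loader repeats the tokenized batch (repeat_times = num_samples // n + 1) and truncates back to num_samples.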
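Timing note: with --load_balance set, apply_load_balance(model, optimizer) fires exactly once, at step == args.warmup // 2. Since that lands inside the warmup window, the one-off cost of redistributing experts presumably stays out of the throughput measured over the active steps.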