From ca0f523ac14706f958cf5ce84b694b2710d93553 Mon Sep 17 00:00:00 2001 From: CWHer Date: Wed, 13 Sep 2023 18:03:13 +0800 Subject: [PATCH 1/6] feat: add benchmark train --- .../openmoe/benchmark/benchmark_train.py | 191 ++++++++++++++++++ .../openmoe/benchmark/benchmark_train.sh | 34 ++++ examples/language/openmoe/benchmark/utils.py | 61 ++++++ 3 files changed, 286 insertions(+) create mode 100644 examples/language/openmoe/benchmark/benchmark_train.py create mode 100755 examples/language/openmoe/benchmark/benchmark_train.sh create mode 100644 examples/language/openmoe/benchmark/utils.py diff --git a/examples/language/openmoe/benchmark/benchmark_train.py b/examples/language/openmoe/benchmark/benchmark_train.py new file mode 100644 index 000000000000..cb3ecf371197 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_train.py @@ -0,0 +1,191 @@ +import colossalai +import datasets +import torch +import transformers +from colossalai import get_default_parser +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.moe.manager import MOE_MANAGER +from colossalai.moe.utils import skip_init +from colossalai.utils import get_current_device +from model.modeling_openmoe import OpenMoeForCausalLM +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import Adafactor +from transformers.models.llama import LlamaConfig +from utils import SimpleTimer, print_model_numel + + +class RandomDataset(Dataset): + + def __init__(self, + num_samples: int = 1000, + max_length: int = 2048, + vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, + (num_samples, max_length), + device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids, + device=get_current_device()) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + 'input_ids': self.input_ids[idx], + 'attention_mask': self.attention_mask[idx], + 'labels': self.input_ids[idx] + } + + +def parse_args(): + parser = get_default_parser() + # TODO: add model_name + # parser.add_argument("--model_name", type=str, default="base", choices=["base", "8b"], + # help="Path to pretrained model or model identifier from huggingface.co/models.") + parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size (per dp group) for the training dataloader.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument("--num_samples", type=int, default=200, help="Number of samples in the dataset.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + LOSS_CONFIG = { + "architectures": [ + "OpenMoeForCausalLM" + ], + "capacity_factor_eval": 2.0, + "capacity_factor_train": 1.25, + "drop_tks": True, + "dropout_rate": 0.0, + "expert_parallel": None, + "gated": True, + "head_dim": 64, + "hidden_act": "swiglu", + "hidden_size": 768, + "intermediate_size": 2048, + "label_smoothing": 0.0, + "layer_norm_epsilon": 1e-06, + "min_capacity": 4, + "moe_layer_interval": 4, + "noisy_policy": None, + "num_attention_heads": 12, + "num_experts": 16, + "num_hidden_layers": 12, + "num_key_value_heads": 12, + "pretraining_tp": 1, + "rope_scaling": None, + 
"router_aux_loss_factor": 0.01, + "router_z_loss_factor": 0.0001, + "topk": 2, + "torch_dtype": "float32", + "vocab_size": 256384, + "z_loss_factor": 0.0001 + } + OPTIM_CONFIG = { + "decay_rate": -0.8, + "weight_decay": 0.01, + } + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + # Set up moe + MOE_MANAGER.setup(seed=42, parallel="EP") + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Build OpenMoe model + config = LlamaConfig() + for k, v in LOSS_CONFIG.items(): + setattr(config, k, v) + + with skip_init(): + model = OpenMoeForCausalLM(config) + + logger.info(f"Finish init model with config:\n{config}", ranks=[0]) + model_param = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"Model param count: {model_param/1e6:.2f}M", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Set plugin + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + logger.info(f"Set plugin as {plugin}", ranks=[0]) + + # Prepare tokenizer and dataloader + dataset = RandomDataset(num_samples=args.num_samples) + dataloader = plugin.prepare_dataloader(dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True) + + # Set optimizer + optimizer = Adafactor(model.parameters(), + decay_rate=OPTIM_CONFIG["decay_rate"], + weight_decay=OPTIM_CONFIG["weight_decay"]) + + # Set booster + booster = Booster(plugin=plugin) + model, optimizer, _, dataloader, _ = booster.boost(model=model, + optimizer=optimizer, + dataloader=dataloader) + + # Start benchmark + model.train() + logger.info(f"Start benchmark", ranks=[0]) + + timer = SimpleTimer() + for epoch in range(args.num_epoch): + for batch in tqdm(dataloader, + desc=f'Epoch [{epoch + 1}]', + disable=not coordinator.is_master()): + timer.start("train_step") + + # Forward + timer.start("forward") + outputs = model(use_cache=False, chunk_head=True, **batch) + loss = outputs['loss'] + torch.cuda.synchronize() + timer.stop("forward") + + # Backward + timer.start("backward") + booster.backward(loss, optimizer) + torch.cuda.synchronize() + timer.stop("backward") + + # Optimizer step + timer.start("optimizer_step") + optimizer.step() + optimizer.zero_grad() + torch.cuda.synchronize() + timer.stop("optimizer_step") + + timer.stop("train_step") + + logger.info(f"Benchmark result:\n{repr(timer)}", ranks=[0]) + + +if __name__ == "__main__": + main() diff --git a/examples/language/openmoe/benchmark/benchmark_train.sh b/examples/language/openmoe/benchmark/benchmark_train.sh new file mode 100755 index 000000000000..0496a31a7479 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_train.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +set -xue + +BENCHMARK_DIR=benchmark +NUM_GPU=2 + +set_n_least_used_CUDA_VISIBLE_DEVICES() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | + tail -n +2 | + nl -v 0 | + tee /dev/tty | + sort -g -k 2 | + awk '{print $1}' | + head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +set_n_least_used_CUDA_VISIBLE_DEVICES $NUM_GPU + +# HACK: make 
From 6bd4c7ecab14270c8cfa86c21d220a81b7f9f483 Mon Sep 17 00:00:00 2001
From: CWHer
Date: Thu, 14 Sep 2023 11:59:54 +0800
Subject: [PATCH 2/6] perf: use flash_attn

---
 .../openmoe/model/modeling_openmoe.py         | 49 ++++++++++++-------
 1 file changed, 31 insertions(+), 18 deletions(-)

diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 6ccbf64a60e4..41b74c06fe30 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -349,6 +349,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        use_kernel: bool = True,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -407,24 +408,36 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                             f" {attn_weights.size()}")
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
-            if self.training:
-                attention_mask = attention_mask.clone().detach()
-                attention_mask[:, :, :, 0] = 0
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
+        if use_kernel:
+            from flash_attn import flash_attn_func
+            query_states = query_states.transpose(1, 2)
+            key_states = key_states.transpose(1, 2)
+            value_states = value_states.transpose(1, 2)
+            attn_output = flash_attn_func(query_states,
+                                          key_states,
+                                          value_states,
+                                          softmax_scale=1.0,
+                                          causal=True)
+            attn_output = attn_output.transpose(1, 2).contiguous()
+        else:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                                 f" {attn_weights.size()}")
+
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
+                if self.training:
+                    attention_mask = attention_mask.clone().detach()
+                    attention_mask[:, :, :, 0] = 0
+                attn_weights = attn_weights + attention_mask
+
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
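As context for the kernel path introduced above, a minimal standalone sketch of the layout flash_attn_func expects (this assumes flash-attn 2.x and a CUDA device; the sizes are illustrative, and softmax_scale=1.0 mirrors the unscaled matmul of the eager path):

    import torch
    from flash_attn import flash_attn_func

    bsz, seqlen, n_heads, head_dim = 2, 128, 12, 64
    # flash_attn_func consumes (batch, seqlen, n_heads, head_dim) tensors,
    # hence the transpose(1, 2) round-trip in the patch above
    q = torch.randn(bsz, seqlen, n_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = flash_attn_func(q, k, v, softmax_scale=1.0, causal=True)
    assert out.shape == (bsz, seqlen, n_heads, head_dim)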
From e17a29fb3c0b2c9b4cc2c69d6f220d83f394f2e6 Mon Sep 17 00:00:00 2001
From: CWHer
Date: Thu, 14 Sep 2023 14:15:27 +0800
Subject: [PATCH 3/6] fix: modify benchmark config

---
 examples/language/openmoe/benchmark/benchmark_train.py | 2 +-
 examples/language/openmoe/benchmark/utils.py           | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/language/openmoe/benchmark/benchmark_train.py b/examples/language/openmoe/benchmark/benchmark_train.py
index cb3ecf371197..47f94feea869 100644
--- a/examples/language/openmoe/benchmark/benchmark_train.py
+++ b/examples/language/openmoe/benchmark/benchmark_train.py
@@ -51,7 +51,7 @@ def parse_args():
     parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
     parser.add_argument("--batch_size", type=int, default=4, help="Batch size (per dp group) for the training dataloader.")
     parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
-    parser.add_argument("--num_samples", type=int, default=200, help="Number of samples in the dataset.")
+    parser.add_argument("--num_samples", type=int, default=1000, help="Number of samples in the dataset.")
 
     args = parser.parse_args()
     return args
diff --git a/examples/language/openmoe/benchmark/utils.py b/examples/language/openmoe/benchmark/utils.py
index a0414a4e015c..d2edee64451c 100644
--- a/examples/language/openmoe/benchmark/utils.py
+++ b/examples/language/openmoe/benchmark/utils.py
@@ -36,7 +36,7 @@ def __str__(self) -> str:
 
 
 class SimpleTimer():
-    def __init__(self, warmup: int = 10) -> None:
+    def __init__(self, warmup: int = 20) -> None:
         self.timing_items: Dict[str, TimingItem] = {}
         self.warmup = warmup
 
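A quick sanity check that the new defaults leave enough steps after warmup (a sketch; it assumes the plugin's dataloader shards the dataset evenly across the 2 ranks launched by benchmark_train.sh):

    num_samples, batch_size, world_size, warmup = 1000, 4, 2, 20
    steps_per_epoch = num_samples // world_size // batch_size  # 125 with drop_last=True
    measured_steps = steps_per_epoch - warmup                  # 105 steps feed the averages
    assert measured_steps > 0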
From fdc84dad6c22548652ee9a0f61b3d372256af86d Mon Sep 17 00:00:00 2001
From: CWHer
Date: Thu, 14 Sep 2023 17:39:27 +0800
Subject: [PATCH 4/6] fix: check flash attn installation

---
 .../openmoe/model/modeling_openmoe.py         | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 41b74c06fe30..4775a3ebea0d 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -24,24 +24,23 @@
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
+from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
+from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
+from colossalai.moe.layers import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
 from torch import nn
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_outputs import (BaseModelOutputWithPast,
+                                           CausalLMOutputWithPast)
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.llama import LlamaConfig
 from transformers.models.t5.modeling_t5 import T5LayerNorm
-from transformers.utils import (
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
-
-from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
-from colossalai.moe.layers import SparseMLP
-from colossalai.moe.manager import MOE_MANAGER
+from transformers.utils import (add_start_docstrings,
+                                add_start_docstrings_to_model_forward, logging,
+                                replace_return_docstrings)
 
 if HAS_TRITON:
-    from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
+    from colossalai.kernel.triton.llama_act_combine_kernel import \
+        LlamaActCombine
 
 logger = logging.get_logger(__name__)
@@ -408,7 +407,7 @@ def forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        if use_kernel:
+        if HAS_FLASH_ATTN and use_kernel:
             from flash_attn import flash_attn_func
             query_states = query_states.transpose(1, 2)
             key_states = key_states.transpose(1, 2)
From 98cdb225b7d68a535c00e22068cc89148ee9b945 Mon Sep 17 00:00:00 2001
From: CWHer
Date: Thu, 14 Sep 2023 17:47:26 +0800
Subject: [PATCH 5/6] fix: update config with args

---
 examples/language/openmoe/benchmark/benchmark_train.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/examples/language/openmoe/benchmark/benchmark_train.py b/examples/language/openmoe/benchmark/benchmark_train.py
index 47f94feea869..373516c56f84 100644
--- a/examples/language/openmoe/benchmark/benchmark_train.py
+++ b/examples/language/openmoe/benchmark/benchmark_train.py
@@ -60,7 +60,7 @@ def parse_args():
 def main():
     args = parse_args()
 
-    LOSS_CONFIG = {
+    MODEL_CONFIG = {
         "architectures": [
             "OpenMoeForCausalLM"
         ],
@@ -97,6 +97,11 @@ def main():
         "weight_decay": 0.01,
     }
 
+    # update config from args
+    for k in MODEL_CONFIG:
+        if hasattr(args, k):
+            MODEL_CONFIG[k] = getattr(args, k)
+
     # Launch ColossalAI
     colossalai.launch_from_torch(config={}, seed=args.seed)
     coordinator = DistCoordinator()
@@ -116,7 +121,7 @@ def main():
     # Build OpenMoe model
     config = LlamaConfig()
-    for k, v in LOSS_CONFIG.items():
+    for k, v in MODEL_CONFIG.items():
         setattr(config, k, v)
 
     with skip_init():
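The override loop added above only touches config keys that collide with an argparse destination; none of the current flags (num_epoch, batch_size, seed, num_samples) do, so it is effectively future-proofing for options like the commented-out --model_name. A minimal sketch of the mechanism, with a hypothetical flag:

    import argparse

    MODEL_CONFIG = {"num_hidden_layers": 12, "topk": 2}

    parser = argparse.ArgumentParser()
    # hypothetical flag whose dest matches a config key
    parser.add_argument("--num_hidden_layers", type=int, default=4)
    args = parser.parse_args([])

    for k in MODEL_CONFIG:
        if hasattr(args, k):
            MODEL_CONFIG[k] = getattr(args, k)
    print(MODEL_CONFIG)  # {'num_hidden_layers': 4, 'topk': 2}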
From 28f0f527bb9a1838784575dbf3187e8cdf0d83ae Mon Sep 17 00:00:00 2001
From: CWHer
Date: Mon, 18 Sep 2023 10:31:49 +0800
Subject: [PATCH 6/6] perf: optimize top2 router

---
 colossalai/moe/routers.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py
index 6fa89a416203..1ac66f7bb78f 100644
--- a/colossalai/moe/routers.py
+++ b/colossalai/moe/routers.py
@@ -47,7 +47,7 @@ def get_capacity(self, logits_shape):
         capacity += capacity % 2
         capacity = max(capacity, self.min_capacity)
         assert capacity > 0
-        return capacity
+        return int(capacity)
 
     def set_aux_loss(self,
                      router_probs: torch.Tensor,
@@ -299,15 +299,27 @@ def forward(self,
             return probs, mask, dest_idx, num_experts * capacity
         else:
+            # >>> original code
+            # weight1 = mask1 * probs.type_as(inputs)
+            # weight2 = mask2 * probs.type_as(inputs)
+            # rank1_sc = F.one_hot(rank1, num_classes=capacity)
+            # rank2_sc = F.one_hot(rank2, num_classes=capacity)
+
+            # cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
+            # cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
+            # cb_weight = cb_weight1 + cb_weight2
+            # sec_mask = cb_weight.bool()
+
             weight1 = mask1 * probs.type_as(inputs)
             weight2 = mask2 * probs.type_as(inputs)
-            rank1_sc = F.one_hot(rank1, num_classes=capacity)
-            rank2_sc = F.one_hot(rank2, num_classes=capacity)
-            cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
-            cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
-            cb_weight = cb_weight1 + cb_weight2
-            sec_mask = cb_weight.bool()
+            cb_weight = torch.zeros(inputs.shape + (capacity, ), device=inputs.device)
+            sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
+            indices = torch.arange(0, inputs.shape[0], device=inputs.device)
+            cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
+            cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
+            sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
+            sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
             return cb_weight, sec_mask
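To see that the scatter-based dispatch matches the one-hot formulation it replaces, here is a small self-contained check (a sketch: sizes are arbitrary, only the top-1 half is shown, and top1_idx/rank1 are built directly instead of by the router; capacity-overflow masking is ignored):

    import torch
    import torch.nn.functional as F

    tokens, experts, capacity = 8, 4, 3
    probs = torch.rand(tokens, experts)
    top1_idx = torch.randint(0, experts, (tokens,))
    rank1 = torch.randint(0, capacity, (tokens,))
    mask1 = F.one_hot(top1_idx, num_classes=experts)

    # one-hot formulation (the commented-out code in the patch)
    weight1 = mask1 * probs
    rank1_sc = F.one_hot(rank1, num_classes=capacity)
    cb_ref = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)

    # scatter formulation: one fancy-indexed write instead of two
    # (tokens, experts, capacity) one-hot products
    cb_new = torch.zeros(tokens, experts, capacity)
    idx = torch.arange(tokens)
    cb_new[idx, top1_idx, rank1] += weight1[idx, top1_idx]

    assert torch.allclose(cb_ref, cb_new)

The scatter form avoids materializing the intermediate one-hot tensors, which is where the memory and time savings of this patch come from.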