diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index fab6c2f0cb7b..1f3bb294a7ca 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -103,7 +103,10 @@ def __init__(self, overlap_communication: bool = True, custom_policy: Policy = None) -> None: - super().__init__() + super().__init__(tp_size=tp_size, + pp_size=pp_size, + num_microbatches=num_microbatches, + microbatch_size=microbatch_size) assert dist.get_world_size() % ( tp_size * pp_size ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}' diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py new file mode 100644 index 000000000000..7f36f8a88925 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_cai.py @@ -0,0 +1,232 @@ +import datasets +import torch +import torch.distributed as dist +import transformers +from model.modeling_openmoe import OpenMoeForCausalLM +from model.openmoe_policy import OpenMoeForCausalLMPolicy +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import Adafactor +from transformers.models.llama import LlamaConfig +from utils import PerformanceEvaluator, get_model_numel + +import colossalai +from colossalai import get_default_parser +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.moe.manager import MOE_MANAGER +from colossalai.utils import get_current_device + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +class RandomDataset(Dataset): + + def __init__(self, 
num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def parse_args(): + # basic settings + parser = get_default_parser() + parser.add_argument( + "--model_name", + type=str, + default="base", + choices=["base", "8b"], + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=4, + help="Batch size (per dp group) for the training dataloader.", + ) + parser.add_argument( + "--seq_length", + type=int, + default=2048, + help="sequence length for the training dataloader.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--plugin", + type=str, + default="hybrid", + help="parallel plugin", + choices=["zero1", "zero2", "hybrid"], + ) + # hybrid plugin + parser.add_argument("--pp_size", type=int, default=2, help="pp size") + parser.add_argument("--dp_size", type=int, default=1, help="dp size") + parser.add_argument("--ep_size", type=int, default=2, help="ep size") + parser.add_argument("--zero_stage", type=int, default=1, help="zero stage in hybrid plugin") + parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size") + # kernel + parser.add_argument( + "--use_kernel", + action="store_true", + help="Use kernel optim. 
Need to install flash attention, apex, triton to enable all kernel optimizations.", + ) + # bench + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--active", type=int, default=20) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Set plugin + booster_kwargs = {} + if args.plugin == "zero1": + dp_size = dist.get_world_size() + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1) + MOE_MANAGER.setup( + seed=42, + parallel="EP", + use_kernel_optim=args.use_kernel, + ) + elif args.plugin == "zero2": + dp_size = dist.get_world_size() + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + MOE_MANAGER.setup( + seed=42, + parallel="EP", + use_kernel_optim=args.use_kernel, + ) + elif args.plugin == "hybrid": + dp_size = dist.get_world_size() // args.pp_size + plugin = MoeHybridParallelPlugin( + tp_size=1, + pp_size=args.pp_size, + zero_stage=args.zero_stage, + microbatch_size=args.microbatch_size, + custom_policy=OpenMoeForCausalLMPolicy(), + enable_fused_normalization=args.use_kernel, + enable_jit_fused=args.use_kernel, + ) + MOE_MANAGER.setup( + seed=42, + parallel="EP", + mode="fixed", + fixed_dp_size=args.dp_size, + fixed_ep_size=args.ep_size, + fixed_pp_size=args.pp_size, + use_kernel_optim=args.use_kernel, + ) + else: + raise ValueError(f"Invalid plugin {args.plugin}") + logger.info(f"Set plugin as {plugin}", ranks=[0]) + + # Build OpenMoe model + repo_name = "hpcaitech/openmoe-" + args.model_name + config = LlamaConfig.from_pretrained(repo_name) + setattr(config, 
"router_aux_loss_factor", 0.1) + setattr(config, "router_z_loss_factor", 0.1) + setattr(config, "label_smoothing", 0.1) + setattr(config, "z_loss_factor", 0.1) + model = OpenMoeForCausalLM(config) + logger.info(f"Finish init model with config:\n{config}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Prepare tokenizer and dataloader + dataset = RandomDataset( + num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size, + max_length=args.seq_length, + ) + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size) + + # Set optimizer + optimizer = Adafactor(model.parameters(), weight_decay=0.01) + + model_numel = get_model_numel(model) + performance_evaluator = PerformanceEvaluator( + model_numel, + enable_grad_checkpoint=True, + ignore_steps=args.warmup, + dp_world_size=dp_size, + ) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) + use_pipeline = (isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1) + is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() + logger.info(f"Finish init booster", ranks=[0]) + + # Start finetuning + logger.info(f"Start finetuning", ranks=[0]) + model.train() + train_dataloader_iter = iter(dataloader) + total_len = len(train_dataloader_iter) - 1 + exmaple_data = next(train_dataloader_iter) + with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar: + for step in pbar: + performance_evaluator.on_step_start(step) + if use_pipeline: + # Forward pass + outputs = booster.execute_pipeline( + train_dataloader_iter, + model, + lambda x, y: x.loss, + optimizer, + return_loss=True, + return_outputs=True, + ) + # Backward and optimize + if is_pp_last_stage: + loss = outputs["loss"] + pbar.set_postfix({"loss": loss.item()}) + else: + # Forward pass + data = 
move_to_cuda(next(train_dataloader_iter), torch.cuda.current_device()) + outputs = model(**data) + loss = outputs["loss"] + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({"loss": loss.item()}) + + optimizer.step() + optimizer.zero_grad() + performance_evaluator.on_step_end(exmaple_data["input_ids"]) + performance_evaluator.on_fit_end() + + +if __name__ == "__main__": + main() diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh new file mode 100755 index 000000000000..24d0c1b23ab2 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_cai.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +set -xue + +NUM_GPU=4 +MODEL="base" +BATCH_SIZE=1 +SEQ_LENGTH=2048 +WARMUP=10 +ACTIVE=10 + +# HACK: make model importable +example_dir=$(dirname $(realpath $(dirname $0))) +if [ -z ${PYTHONPATH+x} ]; then + export PYTHONPATH=$example_dir +else + export PYTHONPATH=$example_dir:$PYTHONPATH +fi + +# hybrid +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size $BATCH_SIZE \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --use_kernel \ + --plugin hybrid \ + --pp_size 2 \ + --dp_size 1 \ + --ep_size 2 \ + --zero_stage 1 \ + --microbatch_size 1 + +# zero1 +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size $BATCH_SIZE \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin zero1 \ + --use_kernel + +# zero2 +torchrun --standalone --nproc_per_node $NUM_GPU \ + $example_dir/benchmark/benchmark_cai.py \ + --model_name $MODEL \ + --batch_size $BATCH_SIZE \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE \ + --plugin zero2 \ + --use_kernel diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py new file mode 100644 index 
000000000000..cb231687ef39 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py @@ -0,0 +1,124 @@ +import argparse +import os + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import tqdm +from model.modeling_openmoe import LlamaConfig, OpenMoeForCausalLM +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler +from transformers import Adafactor +from transformers.models.llama import LlamaConfig +from utils import PerformanceEvaluator, get_model_numel + +from colossalai.moe.manager import MOE_MANAGER + + +class RandomDataset(Dataset): + + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length)) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "attention_mask": self.attention_mask[idx], + "labels": self.input_ids[idx], + } + + +def fsdp_main(rank, world_size, args): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "14501" + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + MOE_MANAGER.setup(seed=42, parallel=None, use_kernel_optim=False) + + dp_size = dist.get_world_size() + dataset = RandomDataset(max_length=args.seq_length, + num_samples=args.batch_size * (args.warmup + args.active) * dp_size) + sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False) + train_kwargs = {"batch_size": args.batch_size, "sampler": sampler} + train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs) + 
torch.cuda.set_device(rank) + + config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name) + setattr(config, "router_aux_loss_factor", 0.1) + setattr(config, "router_z_loss_factor", 0.1) + setattr(config, "label_smoothing", 0.1) + setattr(config, "z_loss_factor", 0.1) + model = OpenMoeForCausalLM(config).to(rank) + # Wrap the model with FSDP + model = FSDP( + model, + mixed_precision=MixedPrecision( + param_dtype=torch.float16, + reduce_dtype=torch.float16, + buffer_dtype=torch.float16, + ), + ) + optimizer = Adafactor(model.parameters()) + model.train() + + model_numel = get_model_numel(model) + performance_evaluator = PerformanceEvaluator( + model_numel, + enable_grad_checkpoint=True, + ignore_steps=args.warmup, + dp_world_size=dist.get_world_size(), + ) + + for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)): + performance_evaluator.on_step_start(step) + input_ids, attention_mask, labels = ( + data["input_ids"].cuda(), + data["attention_mask"].cuda(), + data["labels"].cuda(), + ) + + optimizer.zero_grad() + output = model( + input_ids=input_ids, + labels=labels, + attention_mask=attention_mask, + chunk_head=False, + ) + loss = output["loss"] + loss.backward() + optimizer.step() + performance_evaluator.on_step_end(input_ids) + + performance_evaluator.on_fit_end() + if dist.get_rank() == 0: + print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="base", + choices=["base", "8b"], + help="base or 8b", + ) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--seq_length", type=int, default=2048) + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--active", type=int, default=20) + args = parser.parse_args() + + torch.manual_seed(42) + + WORLD_SIZE = torch.cuda.device_count() + mp.spawn(fsdp_main, 
args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh new file mode 100755 index 000000000000..a4cb32019431 --- /dev/null +++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +set -xue + +NUM_GPU=4 +MODEL="base" +BATCH_SIZE=1 +SEQ_LENGTH=2048 +WARMUP=10 +ACTIVE=10 + +# HACK: make model importable +example_dir=$(dirname $(realpath $(dirname $0))) +if [ -z ${PYTHONPATH+x} ]; then + export PYTHONPATH=$example_dir +else + export PYTHONPATH=$example_dir:$PYTHONPATH +fi + +python $example_dir/benchmark/benchmark_fsdp.py \ + --model_name $MODEL \ + --batch_size $BATCH_SIZE \ + --seq_length $SEQ_LENGTH \ + --warmup $WARMUP \ + --active $ACTIVE diff --git a/examples/language/openmoe/benchmark/benchmark_train.py b/examples/language/openmoe/benchmark/benchmark_train.py deleted file mode 100644 index 373516c56f84..000000000000 --- a/examples/language/openmoe/benchmark/benchmark_train.py +++ /dev/null @@ -1,196 +0,0 @@ -import colossalai -import datasets -import torch -import transformers -from colossalai import get_default_parser -from colossalai.booster import Booster -from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.cluster import DistCoordinator -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.moe.manager import MOE_MANAGER -from colossalai.moe.utils import skip_init -from colossalai.utils import get_current_device -from model.modeling_openmoe import OpenMoeForCausalLM -from torch.utils.data import Dataset -from tqdm import tqdm -from transformers import Adafactor -from transformers.models.llama import LlamaConfig -from utils import SimpleTimer, print_model_numel - - -class RandomDataset(Dataset): - - def __init__(self, - num_samples: int = 1000, - max_length: int = 2048, - vocab_size: int = 32000): - self.num_samples = num_samples - 
self.max_length = max_length - self.input_ids = torch.randint(0, vocab_size, - (num_samples, max_length), - device=get_current_device()) - self.attention_mask = torch.ones_like(self.input_ids, - device=get_current_device()) - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx): - return { - 'input_ids': self.input_ids[idx], - 'attention_mask': self.attention_mask[idx], - 'labels': self.input_ids[idx] - } - - -def parse_args(): - parser = get_default_parser() - # TODO: add model_name - # parser.add_argument("--model_name", type=str, default="base", choices=["base", "8b"], - # help="Path to pretrained model or model identifier from huggingface.co/models.") - parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.") - parser.add_argument("--batch_size", type=int, default=4, help="Batch size (per dp group) for the training dataloader.") - parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") - parser.add_argument("--num_samples", type=int, default=1000, help="Number of samples in the dataset.") - - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - - MDOEL_CONFIG = { - "architectures": [ - "OpenMoeForCausalLM" - ], - "capacity_factor_eval": 2.0, - "capacity_factor_train": 1.25, - "drop_tks": True, - "dropout_rate": 0.0, - "expert_parallel": None, - "gated": True, - "head_dim": 64, - "hidden_act": "swiglu", - "hidden_size": 768, - "intermediate_size": 2048, - "label_smoothing": 0.0, - "layer_norm_epsilon": 1e-06, - "min_capacity": 4, - "moe_layer_interval": 4, - "noisy_policy": None, - "num_attention_heads": 12, - "num_experts": 16, - "num_hidden_layers": 12, - "num_key_value_heads": 12, - "pretraining_tp": 1, - "rope_scaling": None, - "router_aux_loss_factor": 0.01, - "router_z_loss_factor": 0.0001, - "topk": 2, - "torch_dtype": "float32", - "vocab_size": 256384, - "z_loss_factor": 0.0001 - } - OPTIM_CONFIG = { - "decay_rate": -0.8, - "weight_decay": 
0.01, - } - - # update config from args - for k in MDOEL_CONFIG: - if hasattr(args, k): - MDOEL_CONFIG[k] = getattr(args, k) - - # Launch ColossalAI - colossalai.launch_from_torch(config={}, seed=args.seed) - coordinator = DistCoordinator() - - # Set up moe - MOE_MANAGER.setup(seed=42, parallel="EP") - - # Manage loggers - disable_existing_loggers() - logger = get_dist_logger() - if coordinator.is_master(): - datasets.utils.logging.set_verbosity_warning() - transformers.utils.logging.set_verbosity_info() - else: - datasets.utils.logging.set_verbosity_error() - transformers.utils.logging.set_verbosity_error() - - # Build OpenMoe model - config = LlamaConfig() - for k, v in MDOEL_CONFIG.items(): - setattr(config, k, v) - - with skip_init(): - model = OpenMoeForCausalLM(config) - - logger.info(f"Finish init model with config:\n{config}", ranks=[0]) - model_param = sum(p.numel() for p in model.parameters() if p.requires_grad) - logger.info(f"Model param count: {model_param/1e6:.2f}M", ranks=[0]) - - # Enable gradient checkpointing - model.gradient_checkpointing_enable() - - # Set plugin - plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) - logger.info(f"Set plugin as {plugin}", ranks=[0]) - - # Prepare tokenizer and dataloader - dataset = RandomDataset(num_samples=args.num_samples) - dataloader = plugin.prepare_dataloader(dataset, - batch_size=args.batch_size, - shuffle=True, - drop_last=True) - - # Set optimizer - optimizer = Adafactor(model.parameters(), - decay_rate=OPTIM_CONFIG["decay_rate"], - weight_decay=OPTIM_CONFIG["weight_decay"]) - - # Set booster - booster = Booster(plugin=plugin) - model, optimizer, _, dataloader, _ = booster.boost(model=model, - optimizer=optimizer, - dataloader=dataloader) - - # Start benchmark - model.train() - logger.info(f"Start benchmark", ranks=[0]) - - timer = SimpleTimer() - for epoch in range(args.num_epoch): - for batch in tqdm(dataloader, - desc=f'Epoch [{epoch + 1}]', - disable=not coordinator.is_master()): - 
timer.start("train_step") - - # Forward - timer.start("forward") - outputs = model(use_cache=False, chunk_head=True, **batch) - loss = outputs['loss'] - torch.cuda.synchronize() - timer.stop("forward") - - # Backward - timer.start("backward") - booster.backward(loss, optimizer) - torch.cuda.synchronize() - timer.stop("backward") - - # Optimizer step - timer.start("optimizer_step") - optimizer.step() - optimizer.zero_grad() - torch.cuda.synchronize() - timer.stop("optimizer_step") - - timer.stop("train_step") - - logger.info(f"Benchmark result:\n{repr(timer)}", ranks=[0]) - - -if __name__ == "__main__": - main() diff --git a/examples/language/openmoe/benchmark/benchmark_train.sh b/examples/language/openmoe/benchmark/benchmark_train.sh deleted file mode 100755 index 0496a31a7479..000000000000 --- a/examples/language/openmoe/benchmark/benchmark_train.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -set -xue - -BENCHMARK_DIR=benchmark -NUM_GPU=2 - -set_n_least_used_CUDA_VISIBLE_DEVICES() { - local n=${1:-"9999"} - echo "GPU Memory Usage:" - local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv | - tail -n +2 | - nl -v 0 | - tee /dev/tty | - sort -g -k 2 | - awk '{print $1}' | - head -n $n) - export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') - echo "Now CUDA_VISIBLE_DEVICES is set to:" - echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" -} - -set_n_least_used_CUDA_VISIBLE_DEVICES $NUM_GPU - -# HACK: make model importable -example_dir=$(dirname $(realpath $(dirname $0))) -if [ -z ${PYTHONPATH+x} ]; then - export PYTHONPATH=$example_dir -else - export PYTHONPATH=$example_dir:$PYTHONPATH -fi - -torchrun --standalone --nproc_per_node $NUM_GPU \ - $example_dir/$BENCHMARK_DIR/benchmark_train.py diff --git a/examples/language/openmoe/benchmark/utils.py b/examples/language/openmoe/benchmark/utils.py index d2edee64451c..7a0955bb028a 100644 --- a/examples/language/openmoe/benchmark/utils.py +++ b/examples/language/openmoe/benchmark/utils.py 
@@ -1,61 +1,126 @@ -import dataclasses -import time -from typing import Dict +from time import time +from typing import Optional +import torch import torch.distributed as dist import torch.nn as nn +from torch import Tensor + from colossalai.logging import DistributedLogger -def print_model_numel(logger: DistributedLogger, - model: nn.Module) -> None: +def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None: B = 1024**3 M = 1024**2 K = 1024 outputs = "Model param count: " model_param = sum(p.numel() for p in model.parameters() if p.requires_grad) if model_param >= B: - outputs += f'{model_param / B:.2f} B\n' + outputs += f"{model_param / B:.2f} B\n" elif model_param >= M: - outputs += f'{model_param / M:.2f} M\n' + outputs += f"{model_param / M:.2f} M\n" elif model_param >= K: - outputs += f'{model_param / K:.2f} K\n' + outputs += f"{model_param / K:.2f} K\n" else: - outputs += f'{model_param}\n' + outputs += f"{model_param}\n" logger.info(outputs, ranks=[0]) -@dataclasses.dataclass -class TimingItem(): - last_time: float = 0.0 - total_time: float = 0.0 - count: float = 0 - - def __str__(self) -> str: - return f"average time: {self.total_time/self.count * 1000:.2f} ms" - - -class SimpleTimer(): - def __init__(self, warmup: int = 20) -> None: - self.timing_items: Dict[str, TimingItem] = {} - self.warmup = warmup - - def start(self, name: str): - if name not in self.timing_items: - self.timing_items[name] = TimingItem() - self.timing_items[name].last_time = time.time() - - def stop(self, name: str): - assert name in self.timing_items - timing_item = self.timing_items[name] - timing_item.total_time += time.time() - timing_item.last_time - timing_item.count += 1 - if timing_item.count > self.warmup: - timing_item.count = 0 - timing_item.total_time = 0.0 - - def __repr__(self) -> str: - result = "[Timer]:\n" - for name, timing_item in self.timing_items.items(): - result += f" {name}: {timing_item}\n" - return result +def get_model_numel(model: 
nn.Module) -> int: + model_param = sum(p.numel() for p in model.parameters() if p.requires_grad) + return model_param + + +def divide(x: float, y: float) -> float: + if y == 0: + return float("inf") + elif y == float("inf"): + return float("nan") + return x / y + + +@torch.no_grad() +def all_reduce_mean(x: float, world_size: int) -> float: + if world_size == 1: + return x + tensor = torch.tensor([x], device=torch.cuda.current_device()) + dist.all_reduce(tensor) + tensor = tensor / world_size + return tensor.item() + + +class Timer: + + def __init__(self) -> None: + self.start_time: Optional[float] = None + self.duration: float = 0.0 + + def start(self) -> None: + self.start_time = time() + + def end(self) -> None: + assert self.start_time is not None + self.duration += time() - self.start_time + self.start_time = None + + def reset(self) -> None: + self.duration = 0.0 + + +class PerformanceEvaluator: + """ + Callback to evaluate the performance of the model. + Args: + model_numel: The number of parameters of the model. + enable_grad_checkpoint: Whether gradient checkpointing is enabled. + ignore_steps: The number of warmup steps to ignore when + calculating the performance. + dp_world_size: The data parallel world size used to scale + the measured throughput. 
+ """ + + def __init__( + self, + model_numel: int, + enable_grad_checkpoint: bool = False, + ignore_steps: int = 0, + dp_world_size: Optional[int] = None, + ) -> None: + self.model_numel = model_numel + self.enable_grad_checkpoint = enable_grad_checkpoint + self.ignore_steps = ignore_steps + self.dp_world_size = dp_world_size + self.world_size = dist.get_world_size() + self.disable: bool = False + self.timer = Timer() + self.num_samples: int = 0 + self.flop: int = 0 + + def on_step_start(self, step: int) -> None: + self.disable = self.ignore_steps > 0 and step < self.ignore_steps + if self.disable: + return + torch.cuda.synchronize() + self.timer.start() + + def on_step_end(self, input_ids: Tensor, **kwargs) -> None: + if self.disable: + return + torch.cuda.synchronize() + self.timer.end() + + batch_size, seq_len = input_ids.shape + + self.num_samples += batch_size + self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))) + + def on_fit_end(self) -> None: + avg_duration = all_reduce_mean(self.timer.duration, self.world_size) + avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12) + mp_world_size = self.world_size // self.dp_world_size + avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size + if dist.get_rank() == 0: + print( + f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, " + f"avg_throughput: {avg_throughput}") + print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}") diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index a774d4e9fd55..4d5ff19936b6 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -24,23 +24,24 @@ import torch import torch.nn.functional as F import torch.utils.checkpoint +from torch import nn 
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRMSNorm +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER -from torch import nn -from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast) -from transformers.modeling_utils import PreTrainedModel -from transformers.models.llama import LlamaConfig -from transformers.models.t5.modeling_t5 import T5LayerNorm -from transformers.utils import (add_start_docstrings, - add_start_docstrings_to_model_forward, logging, - replace_return_docstrings) if HAS_TRITON: - from colossalai.kernel.triton.llama_act_combine_kernel import \ - LlamaActCombine + from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine logger = logging.get_logger(__name__) @@ -305,23 +306,21 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = flash_attn_func(query_states, - key_states, - value_states, - softmax_scale=1.0, - causal=True) + attn_output = flash_attn_func(query_states, key_states, value_states, softmax_scale=1.0, causal=True) attn_output = attn_output.transpose(1, 2).contiguous() else: attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError(f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}") + raise ValueError( + f"Attention 
weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}") if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) if self.training: attention_mask = attention_mask.clone().detach() attention_mask[:, :, :, 0] = 0 @@ -358,8 +357,8 @@ def __init__(self, config: LlamaConfig, moe: bool): self.hidden_size = config.hidden_size self.moe = moe self.self_attn = OpenMoeAttention(config=config) - self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: self.mlp = SparseMLP( num_experts=config.num_experts, @@ -374,7 +373,7 @@ def __init__(self, config: LlamaConfig, moe: bool): intermediate_size=config.intermediate_size, activation=config.hidden_act, gated=config.gated) - self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = OpenMoeMLP(config) else: self.mlp = OpenMoeMLP(config) @@ -570,7 +569,7 @@ def __init__(self, config: LlamaConfig): OpenMoeDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False) for i in range(config.num_hidden_layers) ]) - self.norm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing diff --git a/examples/language/openmoe/train.py 
b/examples/language/openmoe/train.py index 2099bbde91f5..e276759043a9 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -67,6 +67,7 @@ def parse_args(): "--model_name", type=str, default="base", + choices=["base", "8b"], help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( @@ -132,26 +133,6 @@ def main(): colossalai.launch_from_torch(config={}, seed=args.seed) coordinator = DistCoordinator() - # Set up moe - if args.plugin in ["zero1", "zero2"]: - MOE_MANAGER.setup( - seed=42, - parallel="EP", - use_kernel_optim=False if args.model_name == "test" else args.use_kernel, - ) - elif args.plugin == "hybrid": - assert (args.dp_size * args.ep_size * - args.pp_size == coordinator.world_size), "dp_size * ep_size * pp_size must equal to world_size" - MOE_MANAGER.setup( - seed=42, - parallel="EP", - mode="fixed", - fixed_dp_size=args.dp_size, - fixed_ep_size=args.ep_size, - fixed_pp_size=args.pp_size, - use_kernel_optim=False if args.model_name == "test" else args.use_kernel, - ) - # Manage loggers disable_existing_loggers() logger = get_dist_logger() @@ -162,32 +143,22 @@ def main(): datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() - # Build OpenMoe model - repo_name = "hpcaitech/openmoe-" + args.model_name - if args.model_name == "test": - config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") - config.vocab_size = 32000 - else: - config = LlamaConfig.from_pretrained(repo_name) - setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor) - setattr(config, "router_z_loss_factor", args.router_z_loss_factor) - setattr(config, "label_smoothing", args.label_smoothing) - setattr(config, "z_loss_factor", args.z_loss_factor) - with skip_init(): - model = OpenMoeForCausalLM(config) - if args.model_name != "test": - load_ckpt(repo_name, model) - logger.info(f"Finish init model with config:\n{config}", ranks=[0]) - - # Enable 
gradient checkpointing - model.gradient_checkpointing_enable() - # Set plugin booster_kwargs = {} if args.plugin == "zero1": plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=1) + MOE_MANAGER.setup( + seed=42, + parallel="EP", + use_kernel_optim=args.use_kernel, + ) elif args.plugin == "zero2": plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + MOE_MANAGER.setup( + seed=42, + parallel="EP", + use_kernel_optim=args.use_kernel, + ) elif args.plugin == "hybrid": plugin = MoeHybridParallelPlugin( tp_size=1, @@ -198,13 +169,37 @@ def main(): enable_fused_normalization=args.use_kernel, enable_jit_fused=args.use_kernel, ) + MOE_MANAGER.setup( + seed=42, + parallel="EP", + mode="fixed", + fixed_dp_size=args.dp_size, + fixed_ep_size=args.ep_size, + fixed_pp_size=args.pp_size, + use_kernel_optim=args.use_kernel, + ) else: raise ValueError(f"Invalid plugin {args.plugin}") logger.info(f"Set plugin as {plugin}", ranks=[0]) + # Build OpenMoe model + repo_name = "hpcaitech/openmoe-" + args.model_name + config = LlamaConfig.from_pretrained(repo_name) + setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor) + setattr(config, "router_z_loss_factor", args.router_z_loss_factor) + setattr(config, "label_smoothing", args.label_smoothing) + setattr(config, "z_loss_factor", args.z_loss_factor) + with skip_init(): + model = OpenMoeForCausalLM(config) + load_ckpt(repo_name, model) + logger.info(f"Finish init model with config:\n{config}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + # Prepare tokenizer and dataloader tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 50) + dataset = RandomDataset(num_samples=1000) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # Set optimizer @@ -228,9 +223,9 @@ def main(): desc=f"Epoch [{epoch + 1}/{args.num_epoch}]", disable=not 
coordinator.is_master(), ) as pbar: - # Forward pass for _ in pbar: if use_pipeline: + # Forward pass outputs = booster.execute_pipeline( train_dataloader_iter, model, @@ -244,6 +239,7 @@ def main(): loss = outputs["loss"] pbar.set_postfix({"loss": loss.item()}) else: + # Forward pass data = next(train_dataloader_iter) data = move_to_cuda(data, torch.cuda.current_device()) outputs = model(**data)