Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions colossalai/moe/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_capacity(self, logits_shape):
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return capacity
return int(capacity)

def set_aux_loss(self,
router_probs: torch.Tensor,
Expand Down Expand Up @@ -299,15 +299,27 @@ def forward(self,

return probs, mask, dest_idx, num_experts * capacity
else:
# >>> original code
# weight1 = mask1 * probs.type_as(inputs)
# weight2 = mask2 * probs.type_as(inputs)
# rank1_sc = F.one_hot(rank1, num_classes=capacity)
# rank2_sc = F.one_hot(rank2, num_classes=capacity)

# cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
# cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
# cb_weight = cb_weight1 + cb_weight2
# sec_mask = cb_weight.bool()

weight1 = mask1 * probs.type_as(inputs)
weight2 = mask2 * probs.type_as(inputs)
rank1_sc = F.one_hot(rank1, num_classes=capacity)
rank2_sc = F.one_hot(rank2, num_classes=capacity)

cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
cb_weight = cb_weight1 + cb_weight2
sec_mask = cb_weight.bool()
cb_weight = torch.zeros(inputs.shape + (capacity, ), device=inputs.device)
sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
indices = torch.arange(0, inputs.shape[0], device=inputs.device)
cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]

return cb_weight, sec_mask

Expand Down
196 changes: 196 additions & 0 deletions examples/language/openmoe/benchmark/benchmark_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import colossalai
import datasets
import torch
import transformers
from colossalai import get_default_parser
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device
from model.modeling_openmoe import OpenMoeForCausalLM
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import Adafactor
from transformers.models.llama import LlamaConfig
from utils import SimpleTimer, print_model_numel


class RandomDataset(Dataset):
    """Synthetic dataset of random token ids for throughput benchmarking.

    All samples are pre-generated once on the current device, so iterating
    the dataset adds no host-to-device copies inside the timed loop.
    """

    def __init__(self,
                 num_samples: int = 1000,
                 max_length: int = 2048,
                 vocab_size: int = 32000):
        self.num_samples = num_samples
        self.max_length = max_length
        device = get_current_device()
        # Uniform random token ids in [0, vocab_size).
        self.input_ids = torch.randint(0,
                                       vocab_size, (num_samples, max_length),
                                       device=device)
        # Every position is a real token, so the mask is all ones.
        self.attention_mask = torch.ones_like(self.input_ids, device=device)

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx):
        sample_ids = self.input_ids[idx]
        # Causal-LM style: labels are the input ids themselves.
        return {
            'input_ids': sample_ids,
            'attention_mask': self.attention_mask[idx],
            'labels': sample_ids
        }


def parse_args():
    """Build and parse the benchmark's command-line arguments.

    Returns:
        argparse.Namespace with ``num_epoch``, ``batch_size``, ``seed``
        and ``num_samples`` (all ints).
    """
    parser = get_default_parser()
    # TODO: add model_name
    # parser.add_argument("--model_name", type=str, default="base", choices=["base", "8b"],
    #                     help="Path to pretrained model or model identifier from huggingface.co/models.")
    int_args = (
        ("--num_epoch", 1, "Number of epochs."),
        ("--batch_size", 4, "Batch size (per dp group) for the training dataloader."),
        ("--seed", 42, "A seed for reproducible training."),
        ("--num_samples", 1000, "Number of samples in the dataset."),
    )
    for flag, default, help_text in int_args:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    return parser.parse_args()


def main():
    """Benchmark OpenMoE training throughput with expert parallelism.

    Builds a base-size OpenMoE model from a hard-coded config (entries may
    be overridden by matching CLI arguments), trains it on random token
    data using the LowLevelZero (stage 2) plugin with Adafactor, and logs
    the average forward / backward / optimizer-step wall-clock times.
    """
    args = parse_args()

    # Fixed base-size OpenMoE hyper-parameters.
    # (Renamed from the original's misspelled `MDOEL_CONFIG`.)
    MODEL_CONFIG = {
        "architectures": [
            "OpenMoeForCausalLM"
        ],
        "capacity_factor_eval": 2.0,
        "capacity_factor_train": 1.25,
        "drop_tks": True,
        "dropout_rate": 0.0,
        "expert_parallel": None,
        "gated": True,
        "head_dim": 64,
        "hidden_act": "swiglu",
        "hidden_size": 768,
        "intermediate_size": 2048,
        "label_smoothing": 0.0,
        "layer_norm_epsilon": 1e-06,
        "min_capacity": 4,
        "moe_layer_interval": 4,
        "noisy_policy": None,
        "num_attention_heads": 12,
        "num_experts": 16,
        "num_hidden_layers": 12,
        "num_key_value_heads": 12,
        "pretraining_tp": 1,
        "rope_scaling": None,
        "router_aux_loss_factor": 0.01,
        "router_z_loss_factor": 0.0001,
        "topk": 2,
        "torch_dtype": "float32",
        "vocab_size": 256384,
        "z_loss_factor": 0.0001
    }
    OPTIM_CONFIG = {
        "decay_rate": -0.8,
        "weight_decay": 0.01,
    }

    # Let matching command-line arguments override the model config.
    for k in MODEL_CONFIG:
        if hasattr(args, k):
            MODEL_CONFIG[k] = getattr(args, k)

    # Launch ColossalAI distributed environment (reads torchrun env vars).
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set up MoE with expert parallelism.
    MOE_MANAGER.setup(seed=42, parallel="EP")

    # Keep HF/datasets logging quiet on non-master ranks.
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Build the OpenMoe model from a LlamaConfig carrying the MoE fields.
    config = LlamaConfig()
    for k, v in MODEL_CONFIG.items():
        setattr(config, k, v)

    # skip_init avoids materializing/initializing weights twice.
    with skip_init():
        model = OpenMoeForCausalLM(config)

    logger.info(f"Finish init model with config:\n{config}", ranks=[0])
    model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model param count: {model_param/1e6:.2f}M", ranks=[0])

    # Trade compute for memory so the benchmark fits on small GPUs.
    model.gradient_checkpointing_enable()

    # ZeRO stage-2: shard gradients and optimizer states across ranks.
    plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
    logger.info(f"Set plugin as {plugin}", ranks=[0])

    # Random data avoids tokenizer / dataset downloads in a benchmark.
    dataset = RandomDataset(num_samples=args.num_samples)
    dataloader = plugin.prepare_dataloader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           drop_last=True)

    # Adafactor keeps optimizer-state memory low for large vocab models.
    optimizer = Adafactor(model.parameters(),
                          decay_rate=OPTIM_CONFIG["decay_rate"],
                          weight_decay=OPTIM_CONFIG["weight_decay"])

    # Boost wraps model/optimizer/dataloader with the plugin's machinery.
    booster = Booster(plugin=plugin)
    model, optimizer, _, dataloader, _ = booster.boost(model=model,
                                                       optimizer=optimizer,
                                                       dataloader=dataloader)

    # Start benchmark
    model.train()
    logger.info("Start benchmark", ranks=[0])    # fixed: f-string had no placeholder

    timer = SimpleTimer()
    for epoch in range(args.num_epoch):
        for batch in tqdm(dataloader,
                          desc=f'Epoch [{epoch + 1}]',
                          disable=not coordinator.is_master()):
            timer.start("train_step")

            # Forward; synchronize so we time GPU work, not just kernel launch.
            timer.start("forward")
            outputs = model(use_cache=False, chunk_head=True, **batch)
            loss = outputs['loss']
            torch.cuda.synchronize()
            timer.stop("forward")

            # Backward goes through the booster so ZeRO hooks run.
            timer.start("backward")
            booster.backward(loss, optimizer)
            torch.cuda.synchronize()
            timer.stop("backward")

            # Optimizer step + gradient reset.
            timer.start("optimizer_step")
            optimizer.step()
            optimizer.zero_grad()
            torch.cuda.synchronize()
            timer.stop("optimizer_step")

            timer.stop("train_step")

    logger.info(f"Benchmark result:\n{repr(timer)}", ranks=[0])


if __name__ == "__main__":
    main()
34 changes: 34 additions & 0 deletions examples/language/openmoe/benchmark/benchmark_train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/bin/bash
# Launch benchmark_train.py under torchrun on the NUM_GPU least-loaded GPUs.

set -xue

BENCHMARK_DIR=benchmark
NUM_GPU=2

# Pick the $1 GPUs with the least memory currently in use (per nvidia-smi)
# and expose only them via CUDA_VISIBLE_DEVICES. With no argument, the
# default of 9999 effectively selects all GPUs.
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
# nvidia-smi -> drop CSV header -> number rows from 0 -> echo to tty ->
# sort by used memory -> keep the indices of the n smallest entries.
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
# Join the selected indices with commas, e.g. "0,1".
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES $NUM_GPU

# HACK: make model importable by putting the example dir on PYTHONPATH
example_dir=$(dirname $(realpath $(dirname $0)))
if [ -z ${PYTHONPATH+x} ]; then
export PYTHONPATH=$example_dir
else
export PYTHONPATH=$example_dir:$PYTHONPATH
fi

# One worker process per GPU on a single node.
torchrun --standalone --nproc_per_node $NUM_GPU \
$example_dir/$BENCHMARK_DIR/benchmark_train.py
61 changes: 61 additions & 0 deletions examples/language/openmoe/benchmark/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import dataclasses
import time
from typing import Dict

import torch.distributed as dist
import torch.nn as nn
from colossalai.logging import DistributedLogger


def print_model_numel(logger: DistributedLogger,
                      model: nn.Module) -> None:
    """Log the trainable-parameter count of *model* in human-readable form.

    NOTE(review): uses binary multiples (1024**3, ...) for the B/M/K
    suffixes even though parameter counts are conventionally decimal —
    confirm that this is intended.
    """
    numel = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Largest matching unit wins; fall back to the raw count below 1K.
    suffix = f'{numel}\n'
    for unit, scale in (('B', 1024**3), ('M', 1024**2), ('K', 1024)):
        if numel >= scale:
            suffix = f'{numel / scale:.2f} {unit}\n'
            break
    logger.info("Model param count: " + suffix, ranks=[0])


@dataclasses.dataclass
class TimingItem():
    """Accumulated wall-clock statistics for one named timer.

    Attributes:
        last_time: timestamp (time.time()) of the most recent start().
        total_time: accumulated elapsed seconds over counted stops.
        count: number of stop() calls folded into total_time.
    """
    last_time: float = 0.0
    total_time: float = 0.0
    count: int = 0    # fixed annotation: this is an integer counter

    def __str__(self) -> str:
        # Guard against ZeroDivisionError: count is 0 before the first
        # measurement and immediately after a warmup reset.
        if self.count == 0:
            return "average time: n/a (no samples)"
        return f"average time: {self.total_time/self.count * 1000:.2f} ms"


class SimpleTimer():
    """Wall-clock timer averaging repeated start()/stop() intervals per name.

    The first measurements of each timer are treated as warmup: once a
    timer has recorded more than ``warmup`` samples, the statistics
    gathered so far are discarded exactly once, so the reported average
    covers only the post-warmup steady state.
    """

    def __init__(self, warmup: int = 20) -> None:
        self.timing_items: Dict[str, TimingItem] = {}
        self.warmup = warmup
        # Timers whose warmup samples have already been discarded.
        self._warmed_up: Dict[str, bool] = {}

    def start(self, name: str):
        """Begin (or re-begin) timing the interval named *name*."""
        if name not in self.timing_items:
            self.timing_items[name] = TimingItem()
        self.timing_items[name].last_time = time.time()

    def stop(self, name: str):
        """End the interval named *name* and fold it into its statistics."""
        assert name in self.timing_items
        timing_item = self.timing_items[name]
        timing_item.total_time += time.time() - timing_item.last_time
        timing_item.count += 1
        # Bug fix: the original reset the statistics EVERY time count
        # exceeded `warmup`, repeatedly throwing away measurements so the
        # average never covered more than `warmup` samples. Discard the
        # warmup samples only once per timer.
        if not self._warmed_up.get(name) and timing_item.count > self.warmup:
            timing_item.count = 0
            timing_item.total_time = 0.0
            self._warmed_up[name] = True

    def __repr__(self) -> str:
        result = "[Timer]:\n"
        for name, timing_item in self.timing_items.items():
            if timing_item.count == 0:
                # Avoid ZeroDivisionError in TimingItem.__str__ when no
                # post-warmup sample has been recorded yet.
                result += f"  {name}: no samples recorded\n"
            else:
                result += f"  {name}: {timing_item}\n"
        return result
Loading