40 changes: 28 additions & 12 deletions colossalai/moe/_operation.py
@@ -3,6 +3,7 @@
import torch
import torch.distributed as dist
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from colossalai.moe.manager import MOE_MANAGER
@@ -133,65 +134,80 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
class MoeDispatch(torch.autograd.Function):

@staticmethod
@custom_fwd
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)
dtype = tokens.dtype

if MOE_KERNEL is None:
load_moe()

if tokens.dtype != torch.float32:
tokens = tokens.to(torch.float32)
expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx)

if expert_input.dtype != dtype:
expert_input = expert_input.to(dtype)
ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec
ctx.dtype = dtype

return expert_input

@staticmethod
@custom_bwd
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
if output_grad.dtype != torch.float32:
output_grad = output_grad.to(torch.float32)
d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
if d_tokens.dtype != ctx.dtype:
d_tokens = d_tokens.to(ctx.dtype)
return d_tokens, None, None, None


class MoeCombine(torch.autograd.Function):

@staticmethod
@custom_fwd
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32

s = logits.size(0)
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)
dtype = expert_tokens.dtype

fp16_flag = expert_tokens.dtype == torch.float16
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
if expert_tokens.dtype != torch.float32:
expert_tokens = expert_tokens.to(torch.float32)
if MOE_KERNEL is None:
load_moe()
ctokens = MOE_KERNEL.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens
output = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx)
if output.dtype != dtype:
output = output.to(dtype)

ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.fp16_flag = fp16_flag
ctx.dtype = dtype

return output

@staticmethod
@custom_bwd
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
if tokens_grad.dtype != torch.float32:
tokens_grad = tokens_grad.to(torch.float32)

cb_grad = (tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad)
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask,
dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits,
mask, dest_idx)
if d_expert.dtype != ctx.dtype:
d_expert = d_expert.to(ctx.dtype)

return d_expert, d_logits, None, None, None
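
Note: the changes above drop the per-tensor `fp16_flag` bookkeeping in favor of a generic dtype round-trip: inputs are cast to float32 for the MoE CUDA kernel and the results are cast back to the caller's dtype, with `custom_fwd`/`custom_bwd` keeping autocast from re-casting tensors inside the autograd Function. A minimal self-contained sketch of that pattern follows; it uses a plain `torch.mm` as an illustrative stand-in for `MOE_KERNEL`, not the actual kernel.

```python
# Minimal sketch of the dtype round-trip pattern used above: cast inputs to
# float32 for a float32-only "kernel", cast the result back to the caller's
# dtype, and save fp32 tensors for the backward pass. torch.mm stands in for
# MOE_KERNEL purely for illustration.
from typing import Any, Tuple

import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class Fp32KernelOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd  # prevents autocast from re-casting inputs inside the Function
    def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype                      # remember the caller's dtype (assume x and w share it)
        x32, w32 = x.to(torch.float32), w.to(torch.float32)
        ctx.save_for_backward(x32, w32)
        ctx.dtype = dtype
        out = torch.mm(x32, w32)             # the fp32-only "kernel"
        return out.to(dtype)                 # hand back the original dtype

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x32, w32 = ctx.saved_tensors
        g32 = grad_out.to(torch.float32)     # gradients also go through fp32
        dx = torch.mm(g32, w32.t()).to(ctx.dtype)
        dw = torch.mm(x32.t(), g32).to(ctx.dtype)
        return dx, dw


# Usage: works the same under torch.autocast or with plain fp16/bf16 inputs.
x = torch.randn(4, 8, dtype=torch.float16, requires_grad=True)
w = torch.randn(8, 2, dtype=torch.float16, requires_grad=True)
Fp32KernelOp.apply(x, w).sum().backward()
```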

54 changes: 47 additions & 7 deletions examples/language/openmoe/benchmark/benchmark_cai.py
@@ -1,3 +1,4 @@
import json
import os

import torch
@@ -7,7 +8,7 @@
from model.openmoe_policy import OpenMoeForCausalLMPolicy
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import Adafactor
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel

@@ -17,6 +18,7 @@
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device
@@ -41,11 +43,36 @@ def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):

class RandomDataset(Dataset):

def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384):
def __init__(self,
num_samples: int = 1000,
max_length: int = 2048,
vocab_size: int = 256384,
tokenizer: T5Tokenizer = None):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)
if os.path.exists("./mock_data.json"):
self.input_ids = []
self.attention_mask = []
with open("./mock_data.json", 'r') as f:
data = json.load(f)
for v in data.values():
d = v["text"]
encode = tokenizer("<pad>" + d,
return_tensors="pt",
add_special_tokens=False,
max_length=max_length,
truncation=True,
padding="max_length")
self.input_ids.append(encode["input_ids"])
self.attention_mask.append(encode["attention_mask"])
self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
repeat_times = num_samples // self.input_ids.shape[0] + 1
self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
else:
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.attention_mask = torch.ones_like(self.input_ids)

def __len__(self):
return self.num_samples
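
Note: `RandomDataset` now prefers a local `./mock_data.json` over purely random token ids. From the code above, the expected layout is a JSON object whose values each carry a `"text"` field (the keys themselves are ignored). The snippet below writes a tiny hypothetical file in that shape and builds the dataset with the umt5-small tokenizer used later in the script; the sample texts and sizes are invented for illustration, and the class is assumed importable from this benchmark script (tensors land on the current device, so a GPU is assumed).

```python
# Hypothetical mock_data.json in the shape RandomDataset reads: a JSON object
# whose values each have a "text" entry. File name and tokenizer follow the
# diff above; the texts and sizes are made up for illustration.
import json

from transformers import T5Tokenizer

with open("./mock_data.json", "w") as f:
    json.dump(
        {
            "0": {"text": "The quick brown fox jumps over the lazy dog."},
            "1": {"text": "MoE layers route each token to a small set of experts."},
        },
        f,
    )

tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
dataset = RandomDataset(num_samples=8, max_length=32, tokenizer=tokenizer)  # class defined above
print(dataset.input_ids.shape)  # torch.Size([8, 32]) -- the mock texts are repeated up to num_samples
```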
@@ -103,6 +130,8 @@ def parse_args():
# bench
parser.add_argument("--warmup", type=int, default=20)
parser.add_argument("--active", type=int, default=20)
# load balance
parser.add_argument("--load_balance", action="store_true")
args = parser.parse_args()
return args

@@ -116,8 +145,14 @@ def main():

# Set plugin
booster_kwargs = {}
hybrid_dict = {"tp_size": 1, "custom_policy": OpenMoeForCausalLMPolicy(), "enable_fused_normalization": args.use_kernel, "enable_jit_fused": args.use_kernel}
mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel}
hybrid_dict = {
"tp_size": 1,
"custom_policy": OpenMoeForCausalLMPolicy(),
"enable_fused_normalization": args.use_kernel,
"enable_jit_fused": args.use_kernel,
"precision": "bf16"
}
mgr_dict = {"seed": 42, "use_kernel_optim": args.use_kernel, "enable_load_balance": args.load_balance}
if args.plugin == "zero":
dp_size = dist.get_world_size()
plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
@@ -203,14 +238,16 @@ def main():
model.gradient_checkpointing_enable()

# Prepare tokenizer and dataloader
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
dataset = RandomDataset(
num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
max_length=args.seq_length,
tokenizer=tokenizer,
)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)

# Set optimizer
optimizer = Adafactor(model.parameters(), weight_decay=0.01)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)

model_numel = get_model_numel(model)
performance_evaluator = PerformanceEvaluator(
@@ -264,6 +301,9 @@ def main():
optimizer.step()
optimizer.zero_grad()
performance_evaluator.on_step_end(exmaple_data["input_ids"])
if (step == args.warmup // 2) and args.load_balance:
apply_load_balance(model, optimizer)
coordinator.print_on_master(f"Apply load balance")
performance_evaluator.on_fit_end()

