17 changes: 9 additions & 8 deletions colossalai/moe/checkpoint.py
@@ -48,14 +48,15 @@ def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
         """
         for name, param in state_dict.items():
             if ".experts." in name:
-                model_param = dict(model.named_parameters())[name]
-                if is_moe_tensor(model_param):
-                    ep_rank = get_ep_rank(model_param)
-                    ep_size = get_ep_size(model_param)
-                    expert_num = param.shape[0] // ep_size
-                    assert param.shape[0] % ep_size == 0
-                    param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
-                    state_dict[name] = param
+                if name in dict(model.named_parameters()):
+                    model_param = dict(model.named_parameters())[name]
+                    if is_moe_tensor(model_param):
+                        ep_rank = get_ep_rank(model_param)
+                        ep_size = get_ep_size(model_param)
+                        expert_num = param.shape[0] // ep_size
+                        assert param.shape[0] % ep_size == 0
+                        param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
+                        state_dict[name] = param
         dist.barrier()
         return state_dict

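The guarded branch above slices each checkpoint expert tensor down to the experts owned by the local expert-parallel (EP) rank. A minimal standalone sketch of that slicing arithmetic, with made-up shapes (8 experts, 4-way EP):

import torch

ep_size, ep_rank = 4, 1        # assumed: 4-way expert parallelism, local rank 1
full = torch.randn(8, 16, 32)  # hypothetical checkpoint tensor: [num_experts, d_in, d_out]

assert full.shape[0] % ep_size == 0
expert_num = full.shape[0] // ep_size  # experts per rank: 2
local = full[ep_rank * expert_num:(ep_rank + 1) * expert_num]
print(local.shape)                     # torch.Size([2, 16, 32]) -> experts 2 and 3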
32 changes: 8 additions & 24 deletions examples/language/openmoe/benchmark/benchmark_cai.py
@@ -4,7 +4,7 @@
 import torch
 import torch.distributed as dist
 from huggingface_hub import snapshot_download
-from model.modeling_openmoe import OpenMoeForCausalLM
+from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
 from model.openmoe_policy import OpenMoeForCausalLMPolicy
 from torch.utils.data import Dataset
 from tqdm import tqdm
@@ -19,7 +19,7 @@
 from colossalai.cluster import DistCoordinator
 from colossalai.moe.layers import apply_load_balance
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import set_moe_args, skip_init
+from colossalai.moe.utils import skip_init
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.utils import get_current_device

@@ -218,28 +218,12 @@ def main():
     # Build OpenMoe model
     repo_name = "hpcaitech/openmoe-" + args.model_name
     config = LlamaConfig.from_pretrained(repo_name)
-    moe_args = {
-        "num_experts": config.num_experts,
-        "moe_layer_interval": config.moe_layer_interval,
-        "router_topk": 2,
-        "router_capacity_factor_train": 1.25,
-        "router_capacity_factor_eval": 2.0,
-        "router_min_capacity": 4,
-        "router_noisy_policy": None,
-        "router_drop_tks": True,
-        "router_aux_loss_factor": 0.01,
-        "router_z_loss_factor": 0.01,
-        "mlp_gated": True,
-        "label_smoothing": 0.001,
-        "z_loss_factor": 0.01,
-        "enable_load_balance": args.load_balance,
-        "load_balance_tolerance": 0.1,
-        "load_balance_beam_width": 8,
-        "load_balance_group_swap_factor": 0.4,
-        "enable_kernel": args.use_kernel,
-        "enable_comm_overlap": args.overlap_alltoall,
-    }
-    set_moe_args(config, moe_args)
+    set_openmoe_args(config,
+                     num_experts=config.num_experts,
+                     moe_layer_interval=config.moe_layer_interval,
+                     enable_load_balance=args.load_balance,
+                     enable_kernel=args.use_kernel,
+                     enable_comm_overlap=args.overlap_alltoall)
     with skip_init():
         model = OpenMoeForCausalLM(config)
     coordinator.print_on_master(f"Finish init model with config:\n{config}")
61 changes: 11 additions & 50 deletions examples/language/openmoe/infer.py
@@ -1,12 +1,10 @@
 from argparse import ArgumentParser
 
 import torch
-from model.modeling_openmoe import OpenMoeForCausalLM
+from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
 from transformers import T5Tokenizer
 from transformers.models.llama import LlamaConfig
 
-from colossalai.moe.utils import set_moe_args
-
 
 def parse_args():
     parser = ArgumentParser()
@@ -15,59 +13,22 @@ def parse_args():
 
 
 def inference(args):
-
     tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
     if args.model == "test":
         config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
-        moe_args = {
-            "num_experts": config.num_experts,
-            "moe_layer_interval": config.moe_layer_interval,
-            "router_topk": 2,
-            "router_capacity_factor_train": 1.25,
-            "router_capacity_factor_eval": 2.0,
-            "router_min_capacity": 4,
-            "router_noisy_policy": None,
-            "router_drop_tks": True,
-            "router_aux_loss_factor": 0.01,
-            "router_z_loss_factor": 0.01,
-            "mlp_gated": True,
-            "label_smoothing": 0.001,
-            "z_loss_factor": 0.01,
-            "enable_load_balance": False,
-            "load_balance_tolerance": 0.1,
-            "load_balance_beam_width": 8,
-            "load_balance_group_swap_factor": 0.4,
-            "enable_kernel": False,
-            "enable_comm_overlap": False,
-        }
-        set_moe_args(config, moe_args)
+        set_openmoe_args(config,
+                         num_experts=config.num_experts,
+                         moe_layer_interval=config.moe_layer_interval,
+                         enable_kernel=True)
         model = OpenMoeForCausalLM(config)
     else:
         config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}")
-        moe_args = {
-            "num_experts": config.num_experts,
-            "moe_layer_interval": config.moe_layer_interval,
-            "router_topk": 2,
-            "router_capacity_factor_train": 1.25,
-            "router_capacity_factor_eval": 2.0,
-            "router_min_capacity": 4,
-            "router_noisy_policy": None,
-            "router_drop_tks": True,
-            "router_aux_loss_factor": 0.01,
-            "router_z_loss_factor": 0.01,
-            "mlp_gated": True,
-            "label_smoothing": 0.001,
-            "z_loss_factor": 0.01,
-            "enable_load_balance": False,
-            "load_balance_tolerance": 0.1,
-            "load_balance_beam_width": 8,
-            "load_balance_group_swap_factor": 0.4,
-            "enable_kernel": False,
-            "enable_comm_overlap": False,
-        }
-        set_moe_args(config, moe_args)
+        set_openmoe_args(config,
+                         num_experts=config.num_experts,
+                         moe_layer_interval=config.moe_layer_interval,
+                         enable_kernel=False)
         model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config)
-    model = model.eval().half()
+    model = model.eval().bfloat16()
     model = model.to(torch.cuda.current_device())
 
     input_str = """```
@@ -86,7 +47,7 @@ def inference(args):
     # print("model config: ", model.config)
     input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=False)
     input_ids = input_ids.input_ids.to(torch.cuda.current_device())
-    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=16)
+    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
     out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
     print(f"output: \n{out}\n")
 
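Condensing the hunks above, a minimal sketch of the updated inference path for the "test" model. Every call is taken from the diff; the short "Hello" prompt is a placeholder (the script builds a longer input_str), and it needs a CUDA device plus the example's local model package on the import path:

import torch
from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig

tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
set_openmoe_args(config,
                 num_experts=config.num_experts,
                 moe_layer_interval=config.moe_layer_interval,
                 enable_kernel=True)
model = OpenMoeForCausalLM(config).eval().bfloat16()
model = model.to(torch.cuda.current_device())

# "<pad>" prefix and add_special_tokens=False follow the diff above.
inputs = tokenizer("<pad>Hello", return_tensors="pt", add_special_tokens=False)
input_ids = inputs.input_ids.to(torch.cuda.current_device())
out = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=False))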
78 changes: 75 additions & 3 deletions examples/language/openmoe/model/modeling_openmoe.py
@@ -39,7 +39,7 @@
 from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
 from colossalai.moe.layers import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.moe.utils import get_activation
+from colossalai.moe.utils import get_activation, set_moe_args
 
 if HAS_TRITON:
     from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
@@ -49,6 +49,78 @@
 _CONFIG_FOR_DOC = "LlamaConfig"
 
 
+def set_openmoe_args(
+    config: LlamaConfig,
+    num_experts: int,
+    moe_layer_interval: int,
+    router_topk: int = 2,
+    router_capacity_factor_train: float = 1.25,
+    router_capacity_factor_eval: float = 2.0,
+    router_min_capacity: int = 4,
+    router_noisy_policy: str = None,
+    router_drop_tks: bool = True,
+    router_aux_loss_factor: float = 0.01,
+    router_z_loss_factor: float = 0.01,
+    mlp_gated: bool = True,
+    label_smoothing: float = 0.001,
+    z_loss_factor: float = 0.01,
+    enable_load_balance: bool = False,
+    load_balance_tolerance: float = 0.1,
+    load_balance_beam_width: int = 8,
+    load_balance_group_swap_factor: float = 0.4,
+    enable_kernel: bool = False,
+    enable_comm_overlap: bool = False,
+) -> None:
+    """
+    MoE-related arguments.
+    Inserts the MoE arguments into the Llama config.
+
+    Args:
+        config (LlamaConfig): Transformers Llama config.
+        num_experts (int): Number of experts.
+        moe_layer_interval (int): Insert a MoE layer every `moe_layer_interval` layers.
+        router_topk (int, optional): MoE router top-k. Defaults to 2.
+        router_capacity_factor_train (float, optional): MoE router max capacity for training. Defaults to 1.25.
+        router_capacity_factor_eval (float, optional): MoE router max capacity for eval. Defaults to 2.0.
+        router_min_capacity (int, optional): MoE router min capacity. Defaults to 4.
+        router_noisy_policy (str, optional): MoE router noisy policy. One of [Jitter, Gaussian, None]. Defaults to None.
+        router_drop_tks (bool, optional): Whether the MoE router drops tokens that exceed max capacity. Defaults to True.
+        router_aux_loss_factor (float, optional): MoE router aux loss factor. See ST-MoE for details. Defaults to 0.01.
+        router_z_loss_factor (float, optional): MoE router z-loss factor. See ST-MoE for details. Defaults to 0.01.
+        mlp_gated (bool, optional): Use a gate in the MLP. Defaults to True.
+        label_smoothing (float, optional): Label smoothing. Defaults to 0.001.
+        z_loss_factor (float, optional): Z-loss factor for the final outputs' classification. Defaults to 0.01.
+        enable_load_balance (bool, optional): Enable expert load balancing. Defaults to False.
+        load_balance_tolerance (float, optional): Difference tolerance of the expert load balance search. Defaults to 0.1.
+        load_balance_beam_width (int, optional): Beam width of the expert load balance search. Defaults to 8.
+        load_balance_group_swap_factor (float, optional): Expert load balance group swap factor. A larger value discourages swaps. Defaults to 0.4.
+        enable_kernel (bool, optional): Use kernel optimization. Defaults to False.
+        enable_comm_overlap (bool, optional): Use communication overlap for MoE. Recommended for multi-node training. Defaults to False.
+    """
+    moe_args = dict(
+        num_experts=num_experts,
+        moe_layer_interval=moe_layer_interval,
+        router_topk=router_topk,
+        router_capacity_factor_train=router_capacity_factor_train,
+        router_capacity_factor_eval=router_capacity_factor_eval,
+        router_min_capacity=router_min_capacity,
+        router_noisy_policy=router_noisy_policy,
+        router_drop_tks=router_drop_tks,
+        router_aux_loss_factor=router_aux_loss_factor,
+        router_z_loss_factor=router_z_loss_factor,
+        mlp_gated=mlp_gated,
+        label_smoothing=label_smoothing,
+        z_loss_factor=z_loss_factor,
+        enable_load_balance=enable_load_balance,
+        load_balance_tolerance=load_balance_tolerance,
+        load_balance_beam_width=load_balance_beam_width,
+        load_balance_group_swap_factor=load_balance_group_swap_factor,
+        enable_kernel=enable_kernel,
+        enable_comm_overlap=enable_comm_overlap,
+    )
+    set_moe_args(config, moe_args)
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(input_ids_shape: torch.Size,
                       dtype: torch.dtype,
@@ -96,7 +168,7 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc
         output_sin: a float32 Tensor with shape [length, features]
         output_cos: a float32 Tensor with shape [length, features]
     """
-    fraction = torch.arange(0, features, 2, dtype=torch.float64).cuda() / features
+    fraction = torch.arange(0, features, 2, dtype=torch.float32).cuda() / features
     timescale = min_timescale * (max_timescale / min_timescale)**fraction
     rotational_frequency = 1. / timescale
 
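The change above only swaps the frequency table to float32; the geometric spacing of the timescales is unchanged. A small self-contained sketch of what these three lines compute (CPU tensors here so it runs without a GPU; the outer-product step is an assumption about how the function goes on to use rotational_frequency):

import torch

features, length = 8, 4
min_timescale, max_timescale = 1.0, 1e4
fraction = torch.arange(0, features, 2, dtype=torch.float32) / features  # 0, 0.25, 0.5, 0.75
timescale = min_timescale * (max_timescale / min_timescale) ** fraction  # 1, 10, 100, 1000
rotational_frequency = 1.0 / timescale
# Assumed continuation: sin/cos of position * frequency, one column per feature pair.
sinusoid = torch.outer(torch.arange(length, dtype=torch.float32), rotational_frequency)
print(torch.sin(sinusoid).shape)  # torch.Size([4, 4]) -> [length, features // 2]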
@@ -231,7 +303,7 @@ def __init__(self, config: LlamaConfig):
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1e4)
+        self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1.0, 1e4)
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
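With the helper in place, a caller only specifies what deviates from the defaults; set_openmoe_args packs everything into a dict and hands it to set_moe_args from colossalai.moe.utils. A hedged sketch of the effect, assuming set_moe_args copies each entry onto the config as an attribute (that helper is not shown in this diff, so the attribute access is an assumption):

from transformers.models.llama import LlamaConfig
from model.modeling_openmoe import set_openmoe_args

config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
set_openmoe_args(config,
                 num_experts=config.num_experts,
                 moe_layer_interval=config.moe_layer_interval)
# Assuming set_moe_args stores each entry as a config attribute:
print(config.router_topk, config.enable_load_balance)  # 2 False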
8 changes: 7 additions & 1 deletion examples/language/openmoe/test_ci.sh
@@ -7,7 +7,13 @@ python infer.py --model "test"
 torchrun --standalone --nproc_per_node 4 train.py \
     --num_epoch 1 \
     --model_name "test" \
-    --plugin zero2_ep \
+    --plugin "ep" \
     --batch_size 1
 
+torchrun --standalone --nproc_per_node 4 train.py \
+    --num_epoch 1 \
+    --model_name "test" \
+    --plugin "ep_zero" \
+    --batch_size 1
+
 torchrun --standalone --nproc_per_node 4 train.py \