39 changes: 25 additions & 14 deletions colossalai/moe/_operation.py
@@ -5,18 +5,22 @@
from torch import Tensor
from torch.distributed import ProcessGroup

try:
from colossalai._C import moe
except:
from colossalai.moe.manager import MOE_MANAGER

MOE_KERNEL = None


def load_moe():
global MOE_KERNEL
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()

MOE_KERNEL = MOEBuilder().load()
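Note on this hunk: the eager `from colossalai._C import moe` import is replaced by a module-level `MOE_KERNEL` handle that is built lazily through `MOEBuilder().load()`, and each kernel call site now guards on it. A condensed sketch of the resulting pattern (the `some_moe_op` wrapper is only for illustration and is not part of the file):

```python
# Condensed sketch of the lazy kernel loading introduced in this hunk.
MOE_KERNEL = None


def load_moe():
    # Build the fused MoE extension once and cache the handle at module level.
    global MOE_KERNEL
    from colossalai.kernel.op_builder import MOEBuilder

    MOE_KERNEL = MOEBuilder().load()


def some_moe_op(*args):
    # Hypothetical call site: every kernel user checks the handle first,
    # so the extension is only built when a MoE op actually runs.
    if MOE_KERNEL is None:
        load_moe()
    return MOE_KERNEL.dispatch_forward(*args)
```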


class AllGather(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:

if ctx is not None:
ctx.comm_grp = group

@@ -89,7 +93,10 @@ def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)

expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
if MOE_KERNEL is None:
load_moe()

expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx)

ctx.save_for_backward(mask, dest_idx)
ctx.s = s
@@ -101,7 +108,7 @@
@staticmethod
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
return d_tokens, None, None, None


@@ -116,9 +123,11 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
c = ec // e
h = expert_tokens.size(-1)

fp16_flag = (expert_tokens.dtype == torch.float16)
fp16_flag = expert_tokens.dtype == torch.float16
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
if MOE_KERNEL is None:
load_moe()
ctokens = MOE_KERNEL.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens

ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
@@ -134,10 +143,10 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors

cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
else tokens_grad
cb_grad = (tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad)
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask,
dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert

return d_expert, d_logits, None, None, None
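MoeCombine up-casts half-precision tensors to float32 around the fused combine kernel and casts the result back, presumably because the kernel only implements float32. A minimal sketch of that cast-around-kernel pattern (the helper name and the dtype assumption are mine, not taken from the PR):

```python
import torch


def call_fp32_kernel(x: torch.Tensor, kernel) -> torch.Tensor:
    # Up-cast fp16 inputs for a kernel assumed to support only fp32,
    # then restore the original dtype on the way out.
    fp16 = x.dtype == torch.float16
    out = kernel(x.float() if fp16 else x)
    return out.half() if fp16 else out
```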
@@ -146,8 +155,10 @@ def backward(ctx, tokens_grad):
def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag:
return moe.cumsum_sub_one(inputs)
if flag and MOE_MANAGER.use_kernel_optim:
if MOE_KERNEL is None:
load_moe()
return MOE_KERNEL.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
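As the fallback branch shows, `moe_cumsum` computes a cumulative sum along dim 0 shifted down by one; the fused `cumsum_sub_one` kernel is only used for sizes it handles well and when `MOE_MANAGER.use_kernel_optim` is enabled. In MoE routing this value typically gives each token its slot index within its expert's buffer. A small worked example of the fallback expression:

```python
import torch

# Token-to-expert assignment mask: rows are tokens, columns are experts.
mask = torch.tensor([[1, 0],
                     [0, 1],
                     [1, 1]])
positions = torch.cumsum(mask, dim=0) - 1
# tensor([[ 0, -1],
#         [ 0,  0],
#         [ 1,  1]])
# Where mask is 1, the value is that token's position inside the expert's batch.
```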

2 changes: 1 addition & 1 deletion colossalai/moe/layers.py
@@ -58,7 +58,7 @@ def __init__(self,
super().__init__()
self.hidden_size = hidden_size
self.num_experts = num_experts
self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False
self.use_kernel = MOE_MANAGER.use_kernel_optim
self.expert_parallel = expert_parallel
assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}"

1 change: 0 additions & 1 deletion examples/language/openmoe/benchmark/benchmark_fsdp.py
@@ -58,7 +58,6 @@ def fsdp_main(rank, world_size, args):
setattr(config, "label_smoothing", 0.1)
setattr(config, "z_loss_factor", 0.1)
model = OpenMoeForCausalLM(config).to(rank)
# Wrap the model with FSDP
model = FSDP(
model,
mixed_precision=MixedPrecision(
2 changes: 1 addition & 1 deletion examples/language/openmoe/model/modeling_openmoe.py
@@ -776,7 +776,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
chunk_head: Optional[bool] = None,
chunk_head: Optional[bool] = True,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
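The only change in this file flips the default of `chunk_head` from None to True. In causal-LM code a flag like this usually means the LM-head projection and loss are computed chunk by chunk instead of materializing the full logits tensor at once; whether OpenMoe does exactly that is not visible in this diff, so the sketch below only illustrates the general idea (all names and shapes are assumptions):

```python
import torch
import torch.nn.functional as F


def chunked_lm_loss(hidden, lm_head, labels, chunk_size=1024):
    # hidden: [tokens, hidden_size], labels: [tokens].  Project and compute the
    # loss per chunk so the full [tokens, vocab] logits never exist at once.
    total = hidden.new_zeros(())
    for start in range(0, hidden.size(0), chunk_size):
        logits = lm_head(hidden[start:start + chunk_size])
        total = total + F.cross_entropy(logits, labels[start:start + chunk_size],
                                        reduction="sum")
    return total / labels.numel()
```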
11 changes: 11 additions & 0 deletions examples/language/openmoe/test_ci.sh
@@ -0,0 +1,11 @@
pip install -r requirements.txt

# inference
python infer.py --model "test"

# train
torchrun --standalone --nproc_per_node 4 train.py \
--num_epoch 1 \
--model_name "test" \
--plugin zero2 \
--batch_size 1
26 changes: 17 additions & 9 deletions examples/language/openmoe/train.py
@@ -67,7 +67,7 @@ def parse_args():
"--model_name",
type=str,
default="base",
choices=["base", "8b"],
choices=["base", "8b", "test"],
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
@@ -132,6 +132,7 @@ def main():
# Launch ColossalAI
colossalai.launch_from_torch(config={}, seed=args.seed)
coordinator = DistCoordinator()
test_mode = args.model_name == "test"

# Manage loggers
disable_existing_loggers()
@@ -150,14 +151,14 @@
MOE_MANAGER.setup(
seed=42,
parallel="EP",
use_kernel_optim=args.use_kernel,
use_kernel_optim=args.use_kernel if not test_mode else False,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
MOE_MANAGER.setup(
seed=42,
parallel="EP",
use_kernel_optim=args.use_kernel,
use_kernel_optim=args.use_kernel if not test_mode else False,
)
elif args.plugin == "hybrid":
plugin = MoeHybridParallelPlugin(
@@ -166,8 +167,8 @@
zero_stage=args.zero_stage,
microbatch_size=args.microbatch_size,
custom_policy=OpenMoeForCausalLMPolicy(),
enable_fused_normalization=args.use_kernel,
enable_jit_fused=args.use_kernel,
enable_fused_normalization=args.use_kernel if not test_mode else False,
enable_jit_fused=args.use_kernel if not test_mode else False,
)
MOE_MANAGER.setup(
seed=42,
@@ -183,23 +184,30 @@ def main():
logger.info(f"Set plugin as {plugin}", ranks=[0])

# Build OpenMoe model
repo_name = "hpcaitech/openmoe-" + args.model_name
config = LlamaConfig.from_pretrained(repo_name)
if test_mode:
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
config.hidden_size = 64
config.intermediate_size = 128
config.vocab_size = 32000
else:
repo_name = "hpcaitech/openmoe-" + args.model_name
config = LlamaConfig.from_pretrained(repo_name)
setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor)
setattr(config, "router_z_loss_factor", args.router_z_loss_factor)
setattr(config, "label_smoothing", args.label_smoothing)
setattr(config, "z_loss_factor", args.z_loss_factor)
with skip_init():
model = OpenMoeForCausalLM(config)
load_ckpt(repo_name, model)
if not test_mode:
load_ckpt(repo_name, model)
logger.info(f"Finish init model with config:\n{config}", ranks=[0])
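In test mode the script builds a toy-sized model instead of pulling a released checkpoint: the base config is shrunk, `load_ckpt` is skipped, kernel optimizations are forced off, and the random dataset below is cut to 20 samples, so `test_ci.sh` can exercise the full training loop cheaply. For reference, the same toy config can be reproduced stand-alone (values are exactly those set in the test branch above):

```python
from transformers import LlamaConfig

# Toy configuration matching the "test" branch above; the model then starts
# from random weights because load_ckpt() is skipped in test mode.
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
config.hidden_size = 64
config.intermediate_size = 128
config.vocab_size = 32000
```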

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Prepare tokenizer and dataloader
tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
dataset = RandomDataset(num_samples=1000)
dataset = RandomDataset(num_samples=1000 if not test_mode else 20)
dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)

# Set optimizer