diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index bde457947e3f..f483f2d85989 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -5,18 +5,22 @@
 from torch import Tensor
 from torch.distributed import ProcessGroup
 
-try:
-    from colossalai._C import moe
-except:
+from colossalai.moe.manager import MOE_MANAGER
+
+MOE_KERNEL = None
+
+
+def load_moe():
+    global MOE_KERNEL
     from colossalai.kernel.op_builder import MOEBuilder
-    moe = MOEBuilder().load()
+
+    MOE_KERNEL = MOEBuilder().load()
 
 
 class AllGather(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
-
         if ctx is not None:
             ctx.comm_grp = group
 
@@ -89,7 +93,10 @@ def forward(ctx, tokens, mask, dest_idx, ec):
         s = tokens.size(0)
         h = tokens.size(1)
 
-        expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
+        if MOE_KERNEL is None:
+            load_moe()
+
+        expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
 
         ctx.save_for_backward(mask, dest_idx)
         ctx.s = s
@@ -101,7 +108,7 @@ def forward(ctx, tokens, mask, dest_idx, ec):
     @staticmethod
     def backward(ctx, output_grad):
         mask, dest_idx = ctx.saved_tensors
-        d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
+        d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
 
         return d_tokens, None, None, None
 
@@ -116,9 +123,11 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
         c = ec // e
         h = expert_tokens.size(-1)
 
-        fp16_flag = (expert_tokens.dtype == torch.float16)
+        fp16_flag = expert_tokens.dtype == torch.float16
         cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
-        ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
+        if MOE_KERNEL is None:
+            load_moe()
+        ctokens = MOE_KERNEL.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
         output = ctokens.to(torch.float16) if fp16_flag else ctokens
 
         ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
@@ -134,10 +143,10 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
     def backward(ctx, tokens_grad):
         expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
 
-        cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
-            else tokens_grad
+        cb_grad = (tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad)
         cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
-        d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
+        d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask,
+                                                         dest_idx)
         d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
 
         return d_expert, d_logits, None, None, None
@@ -146,8 +155,10 @@ def backward(ctx, tokens_grad):
 
 def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
-    if flag:
-        return moe.cumsum_sub_one(inputs)
+    if flag and MOE_MANAGER.use_kernel_optim:
+        if MOE_KERNEL is None:
+            load_moe()
+        return MOE_KERNEL.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1
diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py
index 1255a4816041..a78bfe0a3d74 100644
--- a/colossalai/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -58,7 +58,7 @@ def __init__(self,
         super().__init__()
         self.hidden_size = hidden_size
         self.num_experts = num_experts
-        self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False
+        self.use_kernel = MOE_MANAGER.use_kernel_optim
         self.expert_parallel = expert_parallel
         assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}"
 
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py
index cb231687ef39..c7357c06e5c7 100644
--- a/examples/language/openmoe/benchmark/benchmark_fsdp.py
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py
@@ -58,7 +58,6 @@ def fsdp_main(rank, world_size, args):
     setattr(config, "label_smoothing", 0.1)
     setattr(config, "z_loss_factor", 0.1)
     model = OpenMoeForCausalLM(config).to(rank)
-    # 使用FSDP将model warp起来
     model = FSDP(
         model,
         mixed_precision=MixedPrecision(
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 4d5ff19936b6..6933f108a09e 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -776,7 +776,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        chunk_head: Optional[bool] = None,
+        chunk_head: Optional[bool] = True,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh
index e69de29bb2d1..86742e088f71 100644
--- a/examples/language/openmoe/test_ci.sh
+++ b/examples/language/openmoe/test_ci.sh
@@ -0,0 +1,11 @@
+pip install -r requirements.txt
+
+# inference
+python infer.py --model "test"
+
+# train
+torchrun --standalone --nproc_per_node 4 train.py \
+    --num_epoch 1 \
+    --model_name "test" \
+    --plugin zero2 \
+    --batch_size 1
diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py
index e276759043a9..a7f46f2f693b 100644
--- a/examples/language/openmoe/train.py
+++ b/examples/language/openmoe/train.py
@@ -67,7 +67,7 @@ def parse_args():
         "--model_name",
         type=str,
         default="base",
-        choices=["base", "8b"],
+        choices=["base", "8b", "test"],
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
     parser.add_argument(
@@ -132,6 +132,7 @@ def main():
     # Launch ColossalAI
     colossalai.launch_from_torch(config={}, seed=args.seed)
     coordinator = DistCoordinator()
+    test_mode = args.model_name == "test"
 
     # Manage loggers
     disable_existing_loggers()
@@ -150,14 +151,14 @@ def main():
         MOE_MANAGER.setup(
             seed=42,
             parallel="EP",
-            use_kernel_optim=args.use_kernel,
+            use_kernel_optim=args.use_kernel if not test_mode else False,
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2)
         MOE_MANAGER.setup(
             seed=42,
             parallel="EP",
-            use_kernel_optim=args.use_kernel,
+            use_kernel_optim=args.use_kernel if not test_mode else False,
         )
     elif args.plugin == "hybrid":
         plugin = MoeHybridParallelPlugin(
@@ -166,8 +167,8 @@ def main():
             zero_stage=args.zero_stage,
             microbatch_size=args.microbatch_size,
             custom_policy=OpenMoeForCausalLMPolicy(),
-            enable_fused_normalization=args.use_kernel,
-            enable_jit_fused=args.use_kernel,
+            enable_fused_normalization=args.use_kernel if not test_mode else False,
+            enable_jit_fused=args.use_kernel if not test_mode else False,
         )
         MOE_MANAGER.setup(
             seed=42,
@@ -183,15 +184,22 @@ def main():
     logger.info(f"Set plugin as {plugin}", ranks=[0])
 
     # Build OpenMoe model
-    repo_name = "hpcaitech/openmoe-" + args.model_name
-    config = LlamaConfig.from_pretrained(repo_name)
+    if test_mode:
+        config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
+        config.hidden_size = 64
+        config.intermediate_size = 128
+        config.vocab_size = 32000
+    else:
+        repo_name = "hpcaitech/openmoe-" + args.model_name
+        config = LlamaConfig.from_pretrained(repo_name)
     setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor)
     setattr(config, "router_z_loss_factor", args.router_z_loss_factor)
     setattr(config, "label_smoothing", args.label_smoothing)
     setattr(config, "z_loss_factor", args.z_loss_factor)
     with skip_init():
         model = OpenMoeForCausalLM(config)
-    load_ckpt(repo_name, model)
+    if not test_mode:
+        load_ckpt(repo_name, model)
     logger.info(f"Finish init model with config:\n{config}", ranks=[0])
 
     # Enable gradient checkpointing
@@ -199,7 +207,7 @@ def main():
 
     # Prepare tokenizer and dataloader
     tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
-    dataset = RandomDataset(num_samples=1000)
+    dataset = RandomDataset(num_samples=1000 if not test_mode else 20)
     dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
 
     # Set optimizer
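The _operation.py changes above defer the kernel load until first use and gate the optimized path on MOE_MANAGER.use_kernel_optim, so importing the module no longer requires the extension to be built. A minimal sketch of that lazy-loading pattern follows; the stub kernel and helper names here are illustrative stand-ins for the JIT-built MOEBuilder extension, not part of the patch.

from types import SimpleNamespace

import torch

_KERNEL = None  # module-level cache, analogous to MOE_KERNEL


def _load_kernel():
    # Stand-in for `MOE_KERNEL = MOEBuilder().load()`: the (possibly expensive)
    # load happens once, on first use, instead of at import time.
    global _KERNEL
    _KERNEL = SimpleNamespace(cumsum_sub_one=lambda x: torch.cumsum(x, dim=0) - 1)


def moe_cumsum(inputs: torch.Tensor, use_kernel_optim: bool = True) -> torch.Tensor:
    # Mirrors the updated moe_cumsum: use the kernel only when optimization is
    # enabled (lazily loading it), otherwise fall back to plain torch.cumsum.
    if use_kernel_optim:
        if _KERNEL is None:
            _load_kernel()
        return _KERNEL.cumsum_sub_one(inputs)
    return torch.cumsum(inputs, dim=0) - 1


if __name__ == "__main__":
    x = torch.ones(8, dtype=torch.long)
    print(moe_cumsum(x))         # first call loads the (stub) kernel
    print(moe_cumsum(x, False))  # pure-PyTorch fallback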