diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index 5622bd271735..b413e12e3650 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -21,6 +21,7 @@ def __init__(self): self.max_ep_size = None self.min_dp_size = None self.aux_loss = None + self.parallel = None self.use_kernel_optim = True self.has_setup = False @@ -34,13 +35,14 @@ def parallel_info_dict(self): def is_initialized(self): return self.has_setup - def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8): + def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: bool = None): assert not self.is_initialized, "MoE distributed context shouldn't be set up again" assert torch.cuda.is_available(), "MoE requires to enable CUDA first" self.world_size = dist.get_world_size() self.max_ep_size = min(max_ep_size, dist.get_world_size()) self.min_dp_size = self.world_size // self.max_ep_size + self.parallel = parallel # Enabling kernel optimization may raise error in some cases # Users can close kernel optimization manually @@ -103,5 +105,8 @@ def add_loss(self, loss): def get_loss(self): return self.aux_loss + def get_parallel(self): + return self.parallel + MOE_CONTEXT = MoeContext() diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py index 34af87bd9d47..3cda5a7f044c 100644 --- a/colossalai/nn/layer/moe/checkpoint.py +++ b/colossalai/nn/layer/moe/checkpoint.py @@ -1,3 +1,4 @@ +from copy import deepcopy from pathlib import Path from typing import Optional @@ -6,32 +7,47 @@ import torch.nn as nn from torch.optim import Optimizer -from colossalai.checkpoint_io import CheckpointIO -from colossalai.tensor.moe_tensor.api import get_ep_group +from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor -class MoeCheckpintIO(CheckpointIO): +class 
MoeCheckpintIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): state_dict = torch.load(checkpoint) - for name, param in model.named_parameters(): + for name, param in state_dict.items(): if '.experts.' in name: - ep_rank = dist.get_rank(get_ep_group(param)) - ep_size = dist.get_world_size(get_ep_group(param)) - for rank in range(ep_size): - new_name = name.replace('.experts.', f'.experts.{rank}.') - if rank == ep_rank: - state_dict[name] = state_dict.pop(new_name) - else: - state_dict.pop(new_name) + model_param = dict(model.named_parameters())[name] + if is_moe_tensor(model_param): + ep_rank = get_ep_rank(model_param) + ep_size = get_ep_size(model_param) + expert_num = param.shape[0] // ep_size + assert param.shape[0] % ep_size == 0 + param = param[ep_rank * expert_num:(ep_rank + 1) * expert_num] + state_dict[name] = param model.load_state_dict(state_dict, strict=strict) def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): state_dict = model.state_dict() + for name, param in model.named_parameters(): + if '.experts.' 
in name and is_moe_tensor(param): + ep_group = get_ep_group(param) + ep_rank = get_ep_rank(param) + ep_size = get_ep_size(param) + dp_rank = get_dp_rank(param) + if dp_rank == 0: + param = param.data.cuda() + all_param = [deepcopy(param) for _ in range(ep_size)] + # gather param from every ep rank + dist.all_gather(all_param, param, group=ep_group) + if ep_rank == 0: + assert dist.get_rank() == 0 + all_param = torch.cat(all_param, dim=0) + state_dict[name] = all_param.cpu() if dist.get_rank() == 0: torch.save(state_dict, checkpoint) dist.barrier() diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 9a51ec2a5c7e..fd93bed97992 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -59,7 +59,11 @@ def __init__( if expert_parallel is not None: with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) + if gated: + nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size)) + nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size)) + else: + nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) self.act = get_activation(activation) @@ -110,32 +114,6 @@ def __init__(self, gated: bool = False): super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate, gated) - def state_dict(self, destination=None, prefix='', keep_vars=False): - dp_rank = dist.get_rank(get_dp_group(self)) - ep_rank = dist.get_rank(get_ep_group(self)) - ep_size = get_ep_size(self) - # dp rank 0 will save the state dict - if dp_rank == 0: - for name, param in self.named_parameters(): - if param is self: - continue - # create buffer - buffer_module = deepcopy(param) - # gather param from every ep rank - for source_rank in range(ep_size): - current_prefix = f"{prefix}{source_rank}." 
-                    if ep_rank == source_rank:
-                        dist.broadcast(param.data, src=source_rank, group=self.moe_info.ep_group)
-                    else:
-                        dist.broadcast(buffer_module.data, src=source_rank, group=self.moe_info.ep_group)
-                    if ep_rank == 0:
-                        if keep_vars:
-                            destination[current_prefix + name] = buffer_module.cpu()
-                        else:
-                            destination[current_prefix + name] = buffer_module.data.cpu()
-
-        dist.barrier()
-
 
 class TPMLPExperts(BaseMLPExperts):
     """Use tensor parallelism to split each expert evenly, which can deploy experts in
diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py
index 369f6c0752ac..5b3542c80595 100644
--- a/colossalai/nn/layer/moe/utils.py
+++ b/colossalai/nn/layer/moe/utils.py
@@ -1,3 +1,4 @@
+import contextlib
 from typing import Callable
 
 import torch
@@ -91,3 +92,30 @@ def SwiGLU(x):
     assert size % 2 == 0, "axis size must be divisible by 2"
     x1, x2 = torch.split(x, size // 2, -1)
     return x1 * (x2 * torch.sigmoid(x2))
+
+
+@contextlib.contextmanager
+def skip_init():
+    """
+    Temporarily replace torch.nn.init initializers with no-ops.
+    """
+
+    def _skip_init(x, *args, **kwargs):
+        return x
+
+    # __enter__: save originals, then patch the torch.nn.init attributes in place
+    fn_saved = []
+    init_fn_list = [
+        torch.nn.init.constant_, torch.nn.init.uniform_, torch.nn.init.normal_, torch.nn.init.xavier_uniform_,
+        torch.nn.init.xavier_normal_, torch.nn.init.kaiming_uniform_, torch.nn.init.kaiming_normal_, torch.nn.init.trunc_normal_
+    ]
+    for fn in init_fn_list:
+        fn_saved.append(fn)
+        setattr(torch.nn.init, fn.__name__, _skip_init)
+
+    yield
+
+    # __exit__: restore the original initializers
+    for fn, fn_orig in zip(init_fn_list, fn_saved):
+        setattr(torch.nn.init, fn.__name__, fn_orig)
+    return
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index 7fdd4cc32c23..a1e028ae6308 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -37,6 +37,7 @@
     replace_return_docstrings,
 )
 
+from colossalai.context import MOE_CONTEXT
 from colossalai.nn.layer.moe.layers import SparseMLP
 
 logger = logging.get_logger(__name__)
 
@@ -99,11 +100,14 @@ def
generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1) - return torch.sin(sinusoid_inp).to(torch.bfloat16), torch.cos(sinusoid_inp).to(torch.bfloat16) + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): """Helper function to apply Rotary Embeddings.""" + cos = cos.to(q.dtype) + sin = sin.to(q.dtype) + if len(k.shape) == 3: # for multi query attention k = k.unsqueeze(2) @@ -405,6 +409,8 @@ def forward( if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}") + if self.training: + attention_mask = attention_mask.clone().detach() attention_mask[:, :, :, 0] = 0 attn_weights = attn_weights + attention_mask @@ -442,18 +448,19 @@ def __init__(self, config: LlamaConfig, moe: bool): self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: - self.mlp = SparseMLP(num_experts=config.num_experts, - top_k=config.topk, - capacity_factor_train=config.capacity_factor_train, - capacity_factor_eval=config.capacity_factor_eval, - min_capacity=config.min_capacity, - noisy_policy=config.noisy_policy, - drop_tks=config.drop_tks, - expert_parallel=config.expert_parallel, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - activation=config.hidden_act, - gated=config.gated) + self.mlp = SparseMLP( + num_experts=config.num_experts, + top_k=config.topk, + capacity_factor_train=config.capacity_factor_train, + capacity_factor_eval=config.capacity_factor_eval, + min_capacity=config.min_capacity, + noisy_policy=config.noisy_policy, + drop_tks=config.drop_tks, + expert_parallel=MOE_CONTEXT.get_parallel() if MOE_CONTEXT.is_initialized else config.expert_parallel, + 
hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + activation=config.hidden_act, + gated=config.gated) self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.extra_mlp = LlamaMLP(config) else: @@ -860,6 +867,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + chunk_head: Optional[bool] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -910,22 +918,59 @@ def forward( lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + # if no training, just do forward + if labels is None: + logits = self.lm_head(hidden_states) + logits = logits.float() + # the vocab size for openmoe is 30w+ + # which causes great activation memory in training, up to 20G for one sequence + # so we use chunk and checkpoint to reduce memory + else: + if chunk_head == True: + + def create_custom_forward(module): + + def custom_forward(*inputs): + logits = module(inputs[0]) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous().float() + shift_labels = inputs[1][..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, 
self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + return loss + + return custom_forward + + loss = 0. + for batch_idx in range(hidden_states.shape[0]): + loss = loss + torch.utils.checkpoint.checkpoint( + create_custom_forward(self.lm_head), + hidden_states[batch_idx, :], + labels[batch_idx, :], + ) + loss = loss / hidden_states.shape[0] + logits = None + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py new file mode 100644 index 000000000000..407809702436 --- /dev/null +++ b/examples/language/openmoe/train.py @@ -0,0 +1,180 @@ +import os + +import datasets +import torch +import transformers +from huggingface_hub import snapshot_download +from model.modeling_openmoe import OpenMoeForCausalLM +from torch.utils.data import Dataset +from tqdm import tqdm +from transformers import T5Tokenizer, get_linear_schedule_with_warmup +from transformers.models.llama import LlamaConfig + +import colossalai +from colossalai import get_default_parser +from colossalai.booster import Booster +from colossalai.booster.plugin import LowLevelZeroPlugin +from colossalai.cluster import DistCoordinator +from colossalai.context import MOE_CONTEXT +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.layer.moe import MoeCheckpintIO +from 
colossalai.nn.layer.moe.utils import skip_init +from colossalai.utils import get_current_device + + +def move_to_cuda(batch, device): + return {k: v.to(device) for k, v in batch.items()} + + +def load_ckpt(repo_name: str, model: OpenMoeForCausalLM): + ckpt_path = snapshot_download(repo_name) + # single ckpt + if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin") + # shard ckpt + elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")): + ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + else: + raise ValueError(f"Invalid checkpoint path: {ckpt_path}") + MoeCheckpintIO().load_model(model, ckpt_path) + + +class RandomDataset(Dataset): + + def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000): + self.num_samples = num_samples + self.max_length = max_length + self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device()) + self.attention_mask = torch.ones_like(self.input_ids) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + return { + 'input_ids': self.input_ids[idx], + 'attention_mask': self.attention_mask[idx], + 'labels': self.input_ids[idx] + } + + +def parse_args(): + parser = get_default_parser() + parser.add_argument("--model_name_or_path", + type=str, + default="base", + help="Path to pretrained model or model identifier from huggingface.co/models.") + parser.add_argument("--output_path", + type=str, + default="./output_model.bin", + help="The path of your saved model after finetuning.") + parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs.") + parser.add_argument("--batch_size", + type=int, + default=4, + help="Batch size (per dp group) for the training dataloader.") + parser.add_argument("--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.") + 
parser.add_argument("--warmup_ratio", + type=float, + default=0.1, + help="Ratio of warmup steps against total training steps.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Launch ColossalAI + colossalai.launch_from_torch(config={}, seed=args.seed) + coordinator = DistCoordinator() + world_size = coordinator.world_size + + # Set up moe + MOE_CONTEXT.setup(seed=42, parallel="EP") + + # Manage loggers + disable_existing_loggers() + logger = get_dist_logger() + if coordinator.is_master(): + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # Build OpenMoe model + repo_name = "hpcaitech/openmoe-" + args.model_name_or_path + config = LlamaConfig.from_pretrained(repo_name) + with skip_init(): + model = OpenMoeForCausalLM(config) + load_ckpt(repo_name, model) + logger.info(f"Finish init model with config:\n{config}", ranks=[0]) + + # Enable gradient checkpointing + model.gradient_checkpointing_enable() + + # Set plugin + booster_kwargs = {} + plugin = LowLevelZeroPlugin(initial_scale=2**5, stage=2) + logger.info(f"Set plugin as {plugin}", ranks=[0]) + + # Prepare tokenizer and dataloader + tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") + dataset = RandomDataset() + dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) + + # Set optimizer + optimizer = torch.optim.Adam(model.parameters(), + lr=(args.learning_rate * world_size), + weight_decay=args.weight_decay) + + # Set lr scheduler + total_steps = len(dataloader) * args.num_epoch + num_warmup_steps = int(args.warmup_ratio * total_steps) + lr_scheduler = 
get_linear_schedule_with_warmup(optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=len(dataloader) * args.num_epoch) + + # Set booster + booster = Booster(plugin=plugin, **booster_kwargs) + model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model, + optimizer=optimizer, + dataloader=dataloader, + lr_scheduler=lr_scheduler) + logger.info(f"Finish init booster", ranks=[0]) + + # Start finetuning + logger.info(f"Start finetuning", ranks=[0]) + for epoch in range(args.num_epoch): + model.train() + with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: + for batch in pbar: + # Forward + optimizer.zero_grad() + batch = move_to_cuda(batch, torch.cuda.current_device()) + + outputs = model(use_cache=False, chunk_head=True, **batch) + loss = outputs['loss'] + + # Backward + booster.backward(loss, optimizer) + optimizer.step() + lr_scheduler.step() + + # Print batch loss + pbar.set_postfix({'loss': loss.item()}) + + # Finish training and evaluate + logger.info(f"Finish finetuning", ranks=[0]) + booster.save_model(model, args.output_path) + logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0]) + + +if __name__ == "__main__": + main()