diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index b413e12e3650..652b9c2382f4 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -20,7 +20,8 @@ def __init__(self): # When we have a maximum expert parallel size, we have a minimum data parallel size naturally self.max_ep_size = None self.min_dp_size = None - self.aux_loss = None + self.router_aux_loss = [] + self.router_z_loss = [] self.parallel = None self.use_kernel_optim = True @@ -97,13 +98,15 @@ def set_kernel_not_use(self): self.use_kernel_optim = False def reset_loss(self): - self.aux_loss = 0 + self.router_aux_loss, self.router_z_loss = [], [] - def add_loss(self, loss): - self.aux_loss += loss + def add_loss(self, aux_loss: float = 0., z_loss: float = 0.): + self.router_aux_loss.append(aux_loss) + self.router_z_loss.append(z_loss) def get_loss(self): - return self.aux_loss + cur_loss = self.router_aux_loss, self.router_z_loss + return cur_loss def get_parallel(self): return self.parallel diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 6b7be9eb57c0..3f65bcde8b29 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -135,8 +135,7 @@ def forward(self, inputs: torch.Tensor) -> Tuple: ans = torch.matmul(combine_weights, expert_output) ans = ans.reshape(inputs.shape) - l_aux = self.router.pop_routing_loss() - return ans, l_aux + return ans def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor: expert_in = expert_in.unsqueeze(0) diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py index 962aec9cf1e7..9332302a096a 100644 --- a/colossalai/nn/layer/moe/routers.py +++ b/colossalai/nn/layer/moe/routers.py @@ -38,7 +38,8 @@ def __init__(self, self.min_capacity = min_capacity self.noisy_func = noisy_func self.drop_tks = drop_tks - self._routing_loss = None + self._aux_loss = None + self._z_loss = None def get_capacity(self, 
logits_shape): capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval @@ -48,15 +49,26 @@ def get_capacity(self, logits_shape): assert capacity > 0 return capacity - def set_routing_loss(self, aux_loss: torch.Tensor) -> None: - assert self._routing_loss is None - self._routing_loss = aux_loss + def set_aux_loss(self, logits: torch.Tensor, cmask: torch.Tensor, num_experts: int) -> None: + assert self._aux_loss is None + me = torch.mean(logits, dim=0) + ce = torch.mean(cmask.float(), dim=0) + aux_loss = num_experts * torch.sum(me * ce) + self._aux_loss = aux_loss + + def set_z_loss(self, router_logits: torch.Tensor): + assert self._z_loss is None + n, _ = router_logits.shape + log_z = torch.logsumexp(router_logits, axis=-1) + z_loss = log_z**2 + z_loss = torch.sum(z_loss, dtype=torch.float32) / n + self._z_loss = z_loss - def pop_routing_loss(self) -> torch.Tensor: - assert self._routing_loss is not None - reservation = self._routing_loss - self._routing_loss = None - return reservation + def pop_router_loss(self) -> torch.Tensor: + assert self._aux_loss is not None + MOE_CONTEXT.add_loss(self._aux_loss, self._z_loss) + self._aux_loss = None + self._z_loss = None class Top1Router(MoeRouter): @@ -105,11 +117,10 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti top1_idx = torch.argmax(inputs, dim=-1) mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(mask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) - self.set_routing_loss(l_aux) + # calculate router loss + self.set_aux_loss(logits, mask, num_experts) + self.set_z_loss(inputs) + self.pop_router_loss() if not self.training and not self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(mask, dim=0)) @@ -183,12 +194,12 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti mask2 =
F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) cmask = (mask1 + mask2) # loss: [s, e] + cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(cmask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 - self.set_routing_loss(l_aux) + # calculate loss + self.set_aux_loss(logits, cmask, num_experts) + self.set_z_loss(inputs) + self.pop_router_loss() if not self.training and not self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(cmask, dim=0)) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index a1e028ae6308..1ea9d48523c3 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -18,14 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License.
""" PyTorch OpenMoE model.""" +import math from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss -from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel from transformers.models.llama import LlamaConfig @@ -508,8 +507,6 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - if self.moe: - hidden_states = hidden_states[0] hidden_states = residual + hidden_states if self.moe: @@ -742,7 +739,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # import pdb; pdb.set_trace() # embed positions if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), @@ -894,6 +890,8 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" + # reset moe loss + MOE_CONTEXT.reset_loss() output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (output_hidden_states @@ -939,24 +937,19 @@ def custom_forward(*inputs): shift_logits = logits[..., :-1, :].contiguous().float() shift_labels = inputs[1][..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss = self._calculate_loss(shift_logits, shift_labels) return loss return custom_forward - loss = 0. 
+ aux_loss, z_loss = self._calculate_router_loss() + loss = aux_loss + z_loss for batch_idx in range(hidden_states.shape[0]): loss = loss + torch.utils.checkpoint.checkpoint( create_custom_forward(self.lm_head), - hidden_states[batch_idx, :], - labels[batch_idx, :], + hidden_states[batch_idx:batch_idx + 1, :], + labels[batch_idx:batch_idx + 1, :], ) - loss = loss / hidden_states.shape[0] logits = None else: logits = self.lm_head(hidden_states) @@ -965,12 +958,9 @@ def custom_forward(*inputs): shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + aux_loss, z_loss = self._calculate_router_loss() + loss = aux_loss + z_loss + loss = loss + self._calculate_loss(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] @@ -1022,3 +1012,69 @@ def _reorder_cache(past_key_values, beam_idx): reordered_past += (tuple( past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),) return reordered_past + + def _calculate_router_loss(self): + aux_loss, z_loss = MOE_CONTEXT.get_loss() + assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval + aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss) + z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss) + return aux_loss, z_loss + + def _calculate_loss(self, logits, targets): + if len(logits.shape) != len(targets.shape) + 1: + raise ValueError('Incorrect shapes. 
Got shape %s logits and %s targets' % + (str(logits.shape), str(targets.shape))) + vocab_size = logits.shape[-1] + confidence = 1.0 - self.config.label_smoothing + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -(confidence * math.log(confidence) + + (vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20)) + + # one hot + soft_targets = targets[..., None] == \ + torch.arange(vocab_size, device=targets.device).reshape((1,) * len(targets.shape) + (-1,)) + soft_targets = torch.where(soft_targets, torch.full_like(soft_targets, confidence), + torch.full_like(soft_targets, low_confidence)) + soft_targets = soft_targets.to(torch.float32) + + # cross entropy + total_loss = ZLossCrossEntropy.apply(logits, soft_targets, self.config.z_loss_factor) + total_loss = total_loss - normalizing_constant + total_loss = torch.mean(torch.sum(total_loss, dim=-1), dim=0) + return total_loss + + +class ZLossCrossEntropy(torch.autograd.Function): + + @staticmethod + def forward(ctx, logits, targets, z_loss): + max_logit = torch.max(logits, dim=-1, keepdim=True)[0] + shifted = logits - max_logit + exp_shifted = torch.exp(shifted) + sum_exp = torch.sum(exp_shifted, axis=-1, keepdims=True) + log_softmax = shifted - torch.log(sum_exp) + loss = -torch.sum(targets * log_softmax, axis=-1) + # Add auxiliary z-loss term. + log_z = torch.squeeze(torch.log(sum_exp) + max_logit, axis=-1) + total_z_loss = z_loss * torch.square(log_z) + loss += total_z_loss + ctx.z_loss = z_loss + ctx.save_for_backward(logits, targets, exp_shifted, sum_exp, log_softmax, log_z) + return loss + + @staticmethod + def backward(ctx, *grad_outputs): + assert len(grad_outputs) == 1 + g = grad_outputs[0] + z_loss = ctx.z_loss + logits, targets, exp_shifted, sum_exp, log_softmax, log_z = ctx.saved_tensors + # z-loss term adds the (2 * z_loss * log_z) factor.
+ deriv = ((1 + 2 * z_loss * log_z).unsqueeze(-1) * exp_shifted / sum_exp - targets) + g_logits = g.unsqueeze(-1) * deriv + g_targets = -g.unsqueeze(-1) * log_softmax + + return ( + g_logits.to(logits.dtype), + g_targets.to(targets.dtype), + None, + ) diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt index 2fb95d9c71d3..935a3f1e4ce0 100644 --- a/examples/language/openmoe/requirements.txt +++ b/examples/language/openmoe/requirements.txt @@ -2,3 +2,4 @@ colossalai >= 0.1.12 torch >= 1.8.1 transformers >= 4.20.0 sentencepiece +datasets diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh index 349b2eaccd79..75eee902c747 100644 --- a/examples/language/openmoe/test_ci.sh +++ b/examples/language/openmoe/test_ci.sh @@ -2,3 +2,4 @@ set -xe pip install -r requirements.txt python infer.py --model "test" +torchrun --standalone --nproc_per_node 2 train.py --model_name "test" --batch_size 1 --num_epoch 20 diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py index 407809702436..67dd387a3950 100644 --- a/examples/language/openmoe/train.py +++ b/examples/language/openmoe/train.py @@ -7,7 +7,7 @@ from model.modeling_openmoe import OpenMoeForCausalLM from torch.utils.data import Dataset from tqdm import tqdm -from transformers import T5Tokenizer, get_linear_schedule_with_warmup +from transformers import Adafactor, T5Tokenizer from transformers.models.llama import LlamaConfig import colossalai @@ -60,7 +60,7 @@ def __getitem__(self, idx): def parse_args(): parser = get_default_parser() - parser.add_argument("--model_name_or_path", + parser.add_argument("--model_name", type=str, default="base", help="Path to pretrained model or model identifier from huggingface.co/models.") @@ -73,16 +73,16 @@ def parse_args(): type=int, default=4, help="Batch size (per dp group) for the training dataloader.") - parser.add_argument("--learning_rate", - type=float, - default=5e-5, - 
help="Initial learning rate (after the potential warmup period) to use.") - parser.add_argument("--warmup_ratio", - type=float, - default=0.1, - help="Ratio of warmup steps against total training steps.") - parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + # loss + parser.add_argument("--router_aux_loss_factor", type=float, default=0.01, help="router_aux_loss_factor.") + parser.add_argument("--router_z_loss_factor", type=float, default=0.0001, help="router_z_loss_factor.") + parser.add_argument("--label_smoothing", type=float, default=0.0, help="label_smoothing.") + parser.add_argument("--z_loss_factor", type=float, default=0.0001, help="z_loss_factor.") + # optim + parser.add_argument("--decay_rate", type=float, default=-0.8, help="adafactor optim decay rate.") + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay to use.") + args = parser.parse_args() return args @@ -93,7 +93,6 @@ def main(): # Launch ColossalAI colossalai.launch_from_torch(config={}, seed=args.seed) coordinator = DistCoordinator() - world_size = coordinator.world_size # Set up moe MOE_CONTEXT.setup(seed=42, parallel="EP") @@ -109,11 +108,20 @@ def main(): transformers.utils.logging.set_verbosity_error() # Build OpenMoe model - repo_name = "hpcaitech/openmoe-" + args.model_name_or_path - config = LlamaConfig.from_pretrained(repo_name) + repo_name = "hpcaitech/openmoe-" + args.model_name + if args.model_name == "test": + config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") + config.vocab_size = 32000 + else: + config = LlamaConfig.from_pretrained(repo_name) + setattr(config, "router_aux_loss_factor", args.router_aux_loss_factor) + setattr(config, "router_z_loss_factor", args.router_z_loss_factor) + setattr(config, "label_smoothing", args.label_smoothing) + setattr(config, "z_loss_factor", args.z_loss_factor) with 
skip_init(): model = OpenMoeForCausalLM(config) - load_ckpt(repo_name, model) + if args.model_name != "test": + load_ckpt(repo_name, model) logger.info(f"Finish init model with config:\n{config}", ranks=[0]) # Enable gradient checkpointing @@ -126,27 +134,15 @@ def main(): # Prepare tokenizer and dataloader tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - dataset = RandomDataset() + dataset = RandomDataset(num_samples=1000 if args.model_name != "test" else 1) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) # Set optimizer - optimizer = torch.optim.Adam(model.parameters(), - lr=(args.learning_rate * world_size), - weight_decay=args.weight_decay) - - # Set lr scheduler - total_steps = len(dataloader) * args.num_epoch - num_warmup_steps = int(args.warmup_ratio * total_steps) - lr_scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=len(dataloader) * args.num_epoch) + optimizer = Adafactor(model.parameters(), decay_rate=args.decay_rate, weight_decay=args.weight_decay) # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model, - optimizer=optimizer, - dataloader=dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader) logger.info(f"Finish init booster", ranks=[0]) # Start finetuning @@ -165,7 +161,6 @@ def main(): # Backward booster.backward(loss, optimizer) optimizer.step() - lr_scheduler.step() # Print batch loss pbar.set_postfix({'loss': loss.item()}) diff --git a/examples/language/openmoe/train.sh b/examples/language/openmoe/train.sh new file mode 100644 index 000000000000..9a55779ca5ef --- /dev/null +++ b/examples/language/openmoe/train.sh @@ -0,0 +1,3 @@ +torchrun --standalone --nproc_per_node 2 train.py \ + --model_name "base" \ + --batch_size 4 diff --git 
a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py index a3f6f86e6fe9..7f9bfe376632 100644 --- a/tests/test_moe/moe_utils.py +++ b/tests/test_moe/moe_utils.py @@ -30,9 +30,9 @@ def __init__(self): self.proj = nn.Linear(16, 4) def _forward(self, x): - x, y = self.moe(x) + x = self.moe(x) x = self.proj(x) - return x, y + return x super().__init__() self.test_embed = nn.Linear(4, 16) @@ -42,9 +42,8 @@ def forward(self, x): MOE_CONTEXT.reset_loss() x = self.test_embed(x) - x, y = self.test_transform(x) + x = self.test_transform(x) - MOE_CONTEXT.add_loss(y) return x diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 9544aa0daf01..fd9f30ecb473 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -50,7 +50,7 @@ def run_test(rank, world_size, port): MOE_CONTEXT.reset_loss() for layer in layer_list: - data, _ = layer(data) + data = layer(data) data.backward(grad) grad_handler.handle_gradient() diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 6e40e53311a6..46846206f7d1 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -43,7 +43,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f # use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine layer.use_kernel = False - old_out, _ = layer(tokens) + old_out = layer(tokens) ech = old_out.shape grad = torch.randn(ech, device=get_current_device()) old_out.backward(grad) # get gradient @@ -57,7 +57,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f layer.gate_weight.grad.zero_() layer.use_kernel = True - new_out, _ = layer(tokens) # get outputs through colossal kernel + new_out = layer(tokens) # get outputs through colossal kernel if data_type == torch.float32: check_equal(old_out, new_out) diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py index 
13c66cf73e4d..cb261912e0f6 100644 --- a/tests/test_moe/test_moe_ep_tp.py +++ b/tests/test_moe/test_moe_ep_tp.py @@ -37,9 +37,9 @@ def run_test(rank, world_size, port): tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] - out_tp = tp_model(tp_data)[0] + out_tp = tp_model(tp_data) MOE_CONTEXT.reset_loss() - out_ep = ep_model(ep_data)[0] + out_ep = ep_model(ep_data) MOE_CONTEXT.reset_loss() assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) diff --git a/tests/test_moe/test_moe_local.py b/tests/test_moe/test_moe_local.py index d240ad46ce71..e41a0d821a10 100644 --- a/tests/test_moe/test_moe_local.py +++ b/tests/test_moe/test_moe_local.py @@ -37,9 +37,9 @@ def run_test(rank, world_size, port): tp_data = torch.randn(BATCH_SIZE, DIM, device=get_current_device()) ep_data = tp_data.detach()[2 * rank:2 * (rank + 1)] - out_tp = local_model(tp_data)[0] + out_tp = local_model(tp_data) MOE_CONTEXT.reset_loss() - out_ep = ep_model(ep_data)[0] + out_ep = ep_model(ep_data) MOE_CONTEXT.reset_loss() assert torch.allclose(out_ep, out_tp[2 * rank:2 * (rank + 1)]) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py index 9d19ee830f77..f1f888203746 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd.py @@ -40,7 +40,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False) def run_zero_test(local_rank, world_size, stage=1): - criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) + criterion = torch.nn.CrossEntropyLoss() zero_model = MoeModel(checkpoint=True) optimizer = torch.optim.Adam(zero_model.parameters()) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index fcb6f95d1319..229ee528b4fc 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -39,7 +39,7 @@ def run_fwd_bwd(model, data, label, 
criterion, optimizer, enable_autocast=False) def run_zero_optim_test(local_rank, world_size, stage=1): - criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss) + criterion = torch.nn.CrossEntropyLoss() zero_model = MoeModel(checkpoint=True) zero_optimizer = torch.optim.Adam(zero_model.parameters())