13 changes: 8 additions & 5 deletions colossalai/context/moe_context.py
@@ -20,7 +20,8 @@ def __init__(self):
# When we have a maximum expert parallel size, we have a minimum data parallel size naturally
self.max_ep_size = None
self.min_dp_size = None
self.aux_loss = None
self.router_aux_loss = []
self.router_z_loss = []
self.parallel = None
self.use_kernel_optim = True

@@ -97,13 +98,15 @@ def set_kernel_not_use(self):
self.use_kernel_optim = False

def reset_loss(self):
self.aux_loss = 0
self.router_aux_loss, self.router_z_loss = [], []

def add_loss(self, loss):
self.aux_loss += loss
def add_loss(self, aux_loss: float = 0., z_loss: float = 0.):
self.router_aux_loss.append(aux_loss)
self.router_z_loss.append(z_loss)

def get_loss(self):
return self.aux_loss
cur_loss = self.router_aux_loss, self.router_z_loss
return cur_loss

def get_parallel(self):
return self.parallel
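The context now tracks one (aux, z) loss pair per MoE layer in lists instead of a single running scalar. A minimal usage sketch of the new flow (the import path and the wrapper function are assumptions for illustration; in the PR the accumulation happens inside the routers and the losses are consumed by the model itself):

from colossalai.context.moe_context import MOE_CONTEXT  # assumed import path

def forward_with_router_losses(model, batch):
    MOE_CONTEXT.reset_loss()                        # clear the per-layer lists
    outputs = model(**batch)                        # each MoE router appends (aux, z) via add_loss()
    aux_losses, z_losses = MOE_CONTEXT.get_loss()   # one entry per MoE layer
    return outputs, sum(aux_losses), sum(z_losses)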
3 changes: 1 addition & 2 deletions colossalai/nn/layer/moe/layers.py
@@ -135,8 +135,7 @@ def forward(self, inputs: torch.Tensor) -> Tuple:
ans = torch.matmul(combine_weights, expert_output)

ans = ans.reshape(inputs.shape)
l_aux = self.router.pop_routing_loss()
return ans, l_aux
return ans

def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
expert_in = expert_in.unsqueeze(0)
49 changes: 30 additions & 19 deletions colossalai/nn/layer/moe/routers.py
@@ -38,7 +38,8 @@ def __init__(self,
self.min_capacity = min_capacity
self.noisy_func = noisy_func
self.drop_tks = drop_tks
self._routing_loss = None
self._aux_loss = None
self._z_loss = None

def get_capacity(self, logits_shape):
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
@@ -48,15 +49,26 @@ def get_capacity(self, logits_shape):
assert capacity > 0
return capacity

def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
assert self._routing_loss is None
self._routing_loss = aux_loss
def set_aux_loss(self, logits: torch.Tensor, cmask: torch.Tensor, num_experts: int) -> None:
assert self._aux_loss is None
me = torch.mean(logits, dim=0)
ce = torch.mean(cmask.float(), dim=0)
aux_loss = num_experts * torch.sum(me * ce)
self._aux_loss = aux_loss

def set_z_loss(self, router_logits: torch.Tensor):
assert self._z_loss is None
n, _ = router_logits.shape
log_z = torch.logsumexp(router_logits, axis=-1)
z_loss = log_z**2
z_loss = torch.sum(z_loss, dtype=torch.float32) / n
self._z_loss = z_loss

def pop_routing_loss(self) -> torch.Tensor:
assert self._routing_loss is not None
reservation = self._routing_loss
self._routing_loss = None
return reservation
def pop_router_loss(self) -> torch.Tensor:
assert self._aux_loss is not None
MOE_CONTEXT.add_loss(self._aux_loss, self._z_loss)
self._aux_loss = None
self._z_loss = None


class Top1Router(MoeRouter):
@@ -105,11 +117,10 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti
top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

# caculate the auxiliary loss
me = torch.mean(logits, dim=0)
ce = torch.mean(mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce)
self.set_routing_loss(l_aux)
# calculate router loss
self.set_aux_loss(logits, mask, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()

if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(mask, dim=0))
@@ -183,12 +194,12 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Opti
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)

cmask = (mask1 + mask2) # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1

# caculate the auxiliary loss
me = torch.mean(logits, dim=0)
ce = torch.mean(cmask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
self.set_routing_loss(l_aux)
# calculate loss
self.set_aux_loss(logits, cmask, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()

if not self.training and not self.drop_tks and ep_group is not None:
max_num = torch.max(torch.sum(cmask, dim=0))
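The two router losses introduced here have simple closed forms: the balance (aux) loss is num_experts times the dot product of the per-expert mean routing probability and mean dispatch fraction, and the z-loss is the mean squared logsumexp of the raw gate logits. A standalone sketch of both (toy shapes; `probs` is assumed to be the softmaxed gate output that the routers pass in as `logits`):

import torch
import torch.nn.functional as F

def load_balancing_aux_loss(probs: torch.Tensor, dispatch_mask: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Switch-Transformer-style balance loss, as in set_aux_loss above:
    # num_experts * sum_e (mean routing probability for e) * (mean dispatch fraction for e)
    me = probs.mean(dim=0)
    ce = dispatch_mask.float().mean(dim=0)
    return num_experts * torch.sum(me * ce)

def router_z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    # ST-MoE-style z-loss, as in set_z_loss above: mean squared logsumexp of the raw logits
    log_z = torch.logsumexp(router_logits, dim=-1)
    return torch.sum(log_z ** 2) / router_logits.shape[0]

# toy routing step (shapes and values are illustrative)
tokens, experts = 8, 4
router_logits = torch.randn(tokens, experts)          # raw gate outputs
probs = F.softmax(router_logits, dim=-1)              # routing probabilities
top1_mask = F.one_hot(probs.argmax(dim=-1), num_classes=experts)
print(load_balancing_aux_loss(probs, top1_mask, experts).item(),
      router_z_loss(router_logits).item())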
98 changes: 77 additions & 21 deletions examples/language/openmoe/model/modeling_openmoe.py
@@ -18,14 +18,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OpenMoE model."""
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama import LlamaConfig
@@ -508,8 +507,6 @@ def forward(
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.moe:
hidden_states = hidden_states[0]
hidden_states = residual + hidden_states

if self.moe:
@@ -742,7 +739,6 @@ def forward(

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# import pdb; pdb.set_trace()
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
@@ -894,6 +890,8 @@ def forward(
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
# reset moe loss
MOE_CONTEXT.reset_loss()

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
@@ -939,24 +937,19 @@ def custom_forward(*inputs):
shift_logits = logits[..., :-1, :].contiguous().float()
shift_labels = inputs[1][..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss = self._calculate_loss(shift_logits, shift_labels)
return loss

return custom_forward

loss = 0.
aux_loss, z_loss = self._calculate_router_loss()
loss = aux_loss + z_loss
for batch_idx in range(hidden_states.shape[0]):
loss = loss + torch.utils.checkpoint.checkpoint(
create_custom_forward(self.lm_head),
hidden_states[batch_idx, :],
labels[batch_idx, :],
hidden_states[batch_idx:batch_idx + 1, :],
labels[batch_idx:batch_idx + 1, :],
)
loss = loss / hidden_states.shape[0]
logits = None
else:
logits = self.lm_head(hidden_states)
@@ -965,12 +958,9 @@ def custom_forward(*inputs):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
aux_loss, z_loss = self._calculate_router_loss()
loss = aux_loss + z_loss
loss = loss + self._calculate_loss(shift_logits, shift_labels)

if not return_dict:
output = (logits,) + outputs[1:]
@@ -1022,3 +1012,69 @@ def _reorder_cache(past_key_values, beam_idx):
reordered_past += (tuple(
past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
return reordered_past

def _calculate_router_loss(self):
aux_loss, z_loss = MOE_CONTEXT.get_loss()
assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval
aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss)
z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss)
return aux_loss, z_loss

def _calculate_loss(self, logits, targets):
if len(logits.shape) != len(targets.shape) + 1:
raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
(str(logits.shape), str(targets.shape)))
vocab_size = logits.shape[-1]
confidence = 1.0 - self.config.label_smoothing
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(confidence * math.log(confidence) +
(vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20))

# one hot
soft_targets = targets[..., None] == \
torch.arange(vocab_size, device=targets.device).reshape((1,) * len(targets.shape) + (-1,))
soft_targets = torch.where(soft_targets, torch.full_like(soft_targets, confidence),
torch.full_like(soft_targets, low_confidence))
soft_targets = soft_targets.to(torch.float32)

# cross entropy
total_loss = ZLossCrossEntropy.apply(logits, soft_targets, self.config.z_loss_factor)
total_loss = total_loss - normalizing_constant
total_loss = torch.mean(torch.sum(total_loss, dim=-1), dim=0)
return total_loss


class ZLossCrossEntropy(torch.autograd.Function):

@staticmethod
def forward(ctx, logits, targets, z_loss):
max_logit = torch.max(logits, dim=-1, keepdim=True)[0]
shifted = logits - max_logit
exp_shifted = torch.exp(shifted)
sum_exp = torch.sum(exp_shifted, axis=-1, keepdims=True)
log_softmax = shifted - torch.log(sum_exp)
loss = -torch.sum(targets * log_softmax, axis=-1)
# Add auxiliary z-loss term.
log_z = torch.squeeze(torch.log(sum_exp) + max_logit, axis=-1)
total_z_loss = z_loss * torch.square(log_z)
loss += total_z_loss
ctx.z_loss = z_loss
ctx.save_for_backward(logits, targets, exp_shifted, sum_exp, log_softmax, log_z)
return loss

@staticmethod
def backward(ctx, *grad_outputs):
assert len(grad_outputs) == 1
g = grad_outputs[0]
z_loss = ctx.z_loss
logits, targets, exp_shifted, sum_exp, log_softmax, log_z = ctx.saved_tensors
# z-loss term adds the (2 * z_loss * log_z) factor.
deriv = ((1 + 2 * z_loss * log_z).unsqueeze(-1) * exp_shifted / sum_exp - targets)
g_logits = g.unsqueeze(-1) * deriv
g_targets = -g.unsqueeze(-1) * log_softmax

return (
g_logits.to(logits.dtype),
g_targets.to(targets.dtype),
None,
)
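On the modeling side, `_calculate_router_loss` averages the per-layer aux and z losses collected in MOE_CONTEXT and scales them by `router_aux_loss_factor` and `router_z_loss_factor`, while `_calculate_loss` computes a label-smoothed cross entropy whose fused `ZLossCrossEntropy` also adds `z_loss_factor * logsumexp(logits)**2` per token. Below is a plain-autograd reference sketch for that fused loss, useful as a sanity check against the hand-written backward (toy sizes; the normalizing constant subtracted in `_calculate_loss` is omitted because it does not affect gradients):

import torch

def zloss_xent_reference(logits, soft_targets, z_loss_factor):
    # per-token loss = -sum_j t_j * log_softmax(logits)_j + z * logsumexp(logits)**2
    log_z = torch.logsumexp(logits, dim=-1)
    log_softmax = logits - log_z.unsqueeze(-1)
    return -(soft_targets * log_softmax).sum(dim=-1) + z_loss_factor * log_z ** 2

torch.manual_seed(0)
vocab, confidence = 7, 0.9                        # toy vocab, label smoothing of 0.1
low = (1.0 - confidence) / (vocab - 1)
logits = torch.randn(4, vocab, requires_grad=True)
targets = torch.randint(0, vocab, (4,))
soft = torch.full((4, vocab), low).scatter_(1, targets[:, None], confidence)

loss = zloss_xent_reference(logits, soft, 1e-4).mean()
loss.backward()
# Since the smoothed targets sum to 1, the gradient w.r.t. logits reduces to
# ((1 + 2*z*log_z) * softmax - targets) / N, matching the manual backward above.
print(loss.item(), logits.grad.shape)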
1 change: 1 addition & 0 deletions examples/language/openmoe/requirements.txt
@@ -2,3 +2,4 @@ colossalai >= 0.1.12
torch >= 1.8.1
transformers >= 4.20.0
sentencepiece
datasets
1 change: 1 addition & 0 deletions examples/language/openmoe/test_ci.sh
@@ -2,3 +2,4 @@ set -xe
pip install -r requirements.txt

python infer.py --model "test"
torchrun --standalone --nproc_per_node 2 train.py --model_name "test" --batch_size 1 --num_epoch 20