Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 71 additions & 28 deletions colossalai/moe/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,60 @@ def get_capacity(self, logits_shape):
assert capacity > 0
return capacity

def set_aux_loss(self, logits: torch.Tensor, cmask: torch.Tensor, num_experts: int) -> None:
def set_aux_loss(self,
router_probs: torch.Tensor,
expert_indices: torch.Tensor,
num_experts: int
) -> None:
"""Computes auxiliary load balancing loss as in Switch Transformer.

See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
implements the loss function presented in equations (4) - (6). It aims to
penalize those cases where the routing between experts is unbalanced.

Args:
router_probs: Probability assigned to each expert per token. Shape:
<float32>[num_groups, tokens_per_group, num_experts].
expert_indices: <int>[num_groups, tokens_per_group, num_selected_experts]
indices identifying the top num_selected_experts for a given token.
"""
assert self._aux_loss is None
me = torch.mean(logits, dim=0)
ce = torch.mean(cmask.float(), dim=0)
aux_loss = num_experts * torch.sum(me * ce)
if router_probs.dim() == expert_indices.dim() == 2:
router_probs = router_probs.unsqueeze(0)
expert_indices = expert_indices.unsqueeze(0)
assert router_probs.dim() == expert_indices.dim() == 3, \
"router_probs must be 3D tensor and expert_indices must be 4D tensor"

# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_indices, num_experts)
# For a given token, determine if it was routed to a given expert.
# Shape: [num_groups, tokens_per_group, num_experts]
expert_mask = expert_mask.max(dim=-2)[0]

tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
aux_loss = num_experts**2 * torch.mean(
tokens_per_group_and_expert * router_prob_per_group_and_expert)
self._aux_loss = aux_loss

def set_z_loss(self, router_logits: torch.Tensor):
"""Compute router z-loss.

The router z-loss was introduced in Designing Effective Sparse Expert Models
(https://arxiv.org/abs/2202.08906). It encourages router logits to remain
small in an effort to improve stability.

Args:
router_logits: <float>[num_groups, tokens_per_group, num_experts] router logits.
"""
assert self._z_loss is None
n, _ = router_logits.shape
log_z = torch.logsumexp(router_logits, axis=-1)
z_loss = log_z**2
z_loss = torch.sum(z_loss, dtype=torch.float32) / n
if router_logits.dim() == 2:
router_logits = router_logits.unsqueeze(0)
assert router_logits.dim() == 3, "router_logits must be 3D tensor"
num_groups, tokens_per_group, _ = router_logits.shape
log_z = torch.logsumexp(router_logits, dim=-1)
z_loss = torch.sum(log_z**2, dtype=torch.float32
) / (num_groups * tokens_per_group)
self._z_loss = z_loss

def pop_router_loss(self) -> torch.Tensor:
Expand Down Expand Up @@ -126,15 +167,15 @@ def forward(self,
inputs = self.noisy_func(inputs)

assert inputs.dtype == torch.float
logits = F.softmax(inputs, dim=-1)
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)

top1_idx = torch.argmax(inputs, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

# caculate router loss
self.set_aux_loss(logits, mask, num_experts)
self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()

Expand All @@ -160,10 +201,10 @@ def forward(self,
mask = torch.sum(mask, dim=-1)
mask = torch.stack([mask], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
return logits, mask, dest_idx, num_experts * capacity
return probs, mask, dest_idx, num_experts * capacity
else:
ranks = F.one_hot(ranks, num_classes=capacity)
weight = mask * logits.type_as(inputs)
weight = mask * probs.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask
Expand Down Expand Up @@ -215,21 +256,22 @@ def forward(self,
inputs = self.noisy_func(inputs)

assert inputs.dtype == torch.float
logits = F.softmax(inputs, dim=-1) # logits: [s, e]
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
probs = F.softmax(inputs, dim=-1)
num_experts = probs.size(-1)
capacity = self.get_capacity(inputs.shape)

top1_idx = torch.argmax(logits, dim=-1)
top1_idx = torch.argmax(probs, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)

cmask = (mask1 + mask2) # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1

# caculate loss
self.set_aux_loss(logits, cmask, num_experts)
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
self.set_aux_loss(probs, expert_indices, num_experts)
self.set_z_loss(inputs)
self.pop_router_loss()

Expand All @@ -255,10 +297,10 @@ def forward(self,
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)

return logits, mask, dest_idx, num_experts * capacity
return probs, mask, dest_idx, num_experts * capacity
else:
weight1 = mask1 * logits.type_as(inputs)
weight2 = mask2 * logits.type_as(inputs)
weight1 = mask1 * probs.type_as(inputs)
weight2 = mask2 * probs.type_as(inputs)
rank1_sc = F.one_hot(rank1, num_classes=capacity)
rank2_sc = F.one_hot(rank2, num_classes=capacity)

Expand All @@ -282,9 +324,9 @@ class TopKRouter(MoeRouter):
processed by an expert, or that each expert receives at least one token.

Attributes:
num_selected_experts: Maximum number of experts to which each token is
routed. Tokens may be routed to fewer experts if particular experts are
oversubscribed / reach capacity.
num_selected_experts: Maximum number of experts to which each token is
routed. Tokens may be routed to fewer experts if particular experts are
oversubscribed / reach capacity.
"""

def __init__(self,
Expand Down Expand Up @@ -314,14 +356,15 @@ def forward(self,
Returns:
Dispatch and combine arrays for routing with masked matmuls.
"""
# TODO: add parallel group
num_groups, _, num_experts = router_probs.shape

# Top-k router probability and corresponding expert indices for each token.
# Shape: [num_groups, tokens_per_group, num_selected_experts].
expert_gate, expert_index = torch.topk(router_probs, self.k_value)

# TODO
# auxiliary_loss = _load_balancing_loss(router_probs, expert_index)
self.set_aux_loss(router_probs, expert_index, num_experts)
self.pop_router_loss()

# Make num_selected_experts the leading axis to ensure that top-1 choices
# have priority over top-2 choices, which have priority over top-3 choices,
Expand Down
36 changes: 35 additions & 1 deletion examples/language/openmoe/model/modeling_openmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,19 @@ def _calculate_router_loss(self):
z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss)
return aux_loss, z_loss

def _calculate_loss(self, logits, targets):
def _calculate_loss(self,
logits: torch.Tensor,
targets: torch.Tensor
) -> torch.Tensor:
"""Compute cross entropy and entropy for log probs and targets.

Args:
logits: [batch, length, num_classes] float array.
targets: categorical targets [batch, length] int array.

Returns:
Tuple of scalar loss.
"""
if len(logits.shape) != len(targets.shape) + 1:
raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
(str(logits.shape), str(targets.shape)))
Expand All @@ -1045,6 +1057,28 @@ def _calculate_loss(self, logits, targets):


class ZLossCrossEntropy(torch.autograd.Function):
"""Computes cross entropy loss with stable custom gradient.

Computes a stabilized-gradient version of:
-jnp.sum(targets * nn.log_softmax(logits), axis=-1)

If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2
will be added to the cross entropy loss (z = softmax normalization constant).
The two uses of z_loss are:
1. To keep the logits from drifting too far from zero, which can cause
unacceptable roundoff errors in bfloat16.
2. To encourage the logits to be normalized log-probabilities.

Args:
logits: [batch, length, num_classes] float array.
targets: categorical one-hot targets [batch, length, num_classes] float
array.
z_loss: coefficient for auxilliary z-loss loss term.

Returns:
tuple with the total loss and the z_loss, both
float arrays with shape [batch, length].
"""

@staticmethod
def forward(ctx, logits, targets, z_loss):
Expand Down