From 81a7c7b95fd2ac224c3ab4e5efd7c96b1dfe6a82 Mon Sep 17 00:00:00 2001 From: CWHer Date: Tue, 12 Sep 2023 11:46:46 +0800 Subject: [PATCH 1/5] feat: check z_loss and add doc --- colossalai/moe/routers.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index 688471530758..63a90bdf31be 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -57,11 +57,26 @@ def set_aux_loss(self, logits: torch.Tensor, cmask: torch.Tensor, num_experts: i self._aux_loss = aux_loss def set_z_loss(self, router_logits: torch.Tensor): + """Compute router z-loss. + + The router z-loss was introduced in Designing Effective Sparse Expert Models + (https://arxiv.org/abs/2202.08906). It encourages router logits to remain + small in an effort to improve stability. + + Args: + router_logits: [num_groups, tokens_per_group, num_experts] router logits. + + Returns: + Scalar router z-loss. + """ assert self._z_loss is None - n, _ = router_logits.shape - log_z = torch.logsumexp(router_logits, axis=-1) - z_loss = log_z**2 - z_loss = torch.sum(z_loss, dtype=torch.float32) / n + if router_logits.dim() == 2: + router_logits = router_logits.unsqueeze(0) + assert router_logits.dim() == 3, "router_logits must be 3D tensor" + num_groups, tokens_per_group, _ = router_logits.shape + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = torch.sum(log_z**2, dtype=torch.float32 + ) / (num_groups * tokens_per_group) self._z_loss = z_loss def pop_router_loss(self) -> torch.Tensor: From af94e2f6cf5109376562334691f938bde414f4f9 Mon Sep 17 00:00:00 2001 From: CWHer Date: Tue, 12 Sep 2023 12:08:42 +0800 Subject: [PATCH 2/5] style: rename misleading variable --- colossalai/moe/routers.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index 63a90bdf31be..82cf4509e627 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -141,15 +141,15 @@ def forward(self, inputs = self.noisy_func(inputs) assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) + probs = F.softmax(inputs, dim=-1) + num_experts = probs.size(-1) + capacity = self.get_capacity(inputs.shape) top1_idx = torch.argmax(inputs, dim=-1) mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) # caculate router loss - self.set_aux_loss(logits, mask, num_experts) + self.set_aux_loss(probs, mask, num_experts) self.set_z_loss(inputs) self.pop_router_loss() @@ -175,10 +175,10 @@ def forward(self, mask = torch.sum(mask, dim=-1) mask = torch.stack([mask], dim=0).to(torch.int32) dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) - return logits, mask, dest_idx, num_experts * capacity + return probs, mask, dest_idx, num_experts * capacity else: ranks = F.one_hot(ranks, num_classes=capacity) - weight = mask * logits.type_as(inputs) + weight = mask * probs.type_as(inputs) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) sec_mask = combine_weights.bool() return combine_weights, sec_mask @@ -230,13 +230,13 @@ def forward(self, inputs = self.noisy_func(inputs) assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) # logits: [s, e] - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) + probs = F.softmax(inputs, dim=-1) + num_experts = probs.size(-1) + capacity = self.get_capacity(inputs.shape) - top1_idx = torch.argmax(logits, dim=-1) + top1_idx = torch.argmax(probs, dim=-1) mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + logits_except1 = probs.masked_fill(mask1.bool(), float("-inf")) top2_idx = torch.argmax(logits_except1, dim=-1) mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) @@ -244,7 +244,7 @@ def forward(self, cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # caculate loss - self.set_aux_loss(logits, cmask, num_experts) + self.set_aux_loss(probs, cmask, num_experts) self.set_z_loss(inputs) self.pop_router_loss() @@ -270,10 +270,10 @@ def forward(self, mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - return logits, mask, dest_idx, num_experts * capacity + return probs, mask, dest_idx, num_experts * capacity else: - weight1 = mask1 * logits.type_as(inputs) - weight2 = mask2 * logits.type_as(inputs) + weight1 = mask1 * probs.type_as(inputs) + weight2 = mask2 * probs.type_as(inputs) rank1_sc = F.one_hot(rank1, num_classes=capacity) rank2_sc = F.one_hot(rank2, num_classes=capacity) @@ -335,7 +335,9 @@ def forward(self, # Shape: [num_groups, tokens_per_group, num_selected_experts]. expert_gate, expert_index = torch.topk(router_probs, self.k_value) - # TODO + # TODO: + # 1. add router loss + # 2. add parallel group # auxiliary_loss = _load_balancing_loss(router_probs, expert_index) # Make num_selected_experts the leading axis to ensure that top-1 choices From 46bbe13b68fd980216a4482646f7645cc5393718 Mon Sep 17 00:00:00 2001 From: CWHer Date: Tue, 12 Sep 2023 14:47:33 +0800 Subject: [PATCH 3/5] feat: modify auxiliary loss --- colossalai/moe/routers.py | 42 +++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index 82cf4509e627..4ed7cb08214b 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -49,11 +49,40 @@ def get_capacity(self, logits_shape): assert capacity > 0 return capacity - def set_aux_loss(self, logits: torch.Tensor, cmask: torch.Tensor, num_experts: int) -> None: + def set_aux_loss(self, + router_probs: torch.Tensor, + expert_indices: torch.Tensor, + num_experts: int + ) -> None: + """Computes auxiliary load balancing loss as in Switch Transformer. + + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function + implements the loss function presented in equations (4) - (6). It aims to + penalize those cases where the routing between experts is unbalanced. + + Args: + router_probs: Probability assigned to each expert per token. Shape: + [num_groups, tokens_per_group, num_experts]. + expert_indices: [num_groups, tokens_per_group, num_selected_experts] + indices identifying the top num_selected_experts for a given token. + """ assert self._aux_loss is None - me = torch.mean(logits, dim=0) - ce = torch.mean(cmask.float(), dim=0) - aux_loss = num_experts * torch.sum(me * ce) + if router_probs.dim() == expert_indices.dim() == 2: + router_probs = router_probs.unsqueeze(0) + expert_indices = expert_indices.unsqueeze(0) + assert router_probs.dim() == expert_indices.dim() == 3, \ + "router_probs must be 3D tensor and expert_indices must be 4D tensor" + + # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. + expert_mask = F.one_hot(expert_indices, num_experts) + # For a given token, determine if it was routed to a given expert. + # Shape: [num_groups, tokens_per_group, num_experts] + expert_mask = expert_mask.max(dim=-2)[0] + + tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2) + router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2) + aux_loss = num_experts**2 * torch.mean( + tokens_per_group_and_expert * router_prob_per_group_and_expert) self._aux_loss = aux_loss def set_z_loss(self, router_logits: torch.Tensor): @@ -149,7 +178,7 @@ def forward(self, mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) # caculate router loss - self.set_aux_loss(probs, mask, num_experts) + self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts) self.set_z_loss(inputs) self.pop_router_loss() @@ -244,7 +273,8 @@ def forward(self, cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # caculate loss - self.set_aux_loss(probs, cmask, num_experts) + expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) + self.set_aux_loss(probs, expert_indices, num_experts) self.set_z_loss(inputs) self.pop_router_loss() From 37bf69291b639caa10374059bd10b18f8cadd16a Mon Sep 17 00:00:00 2001 From: CWHer Date: Tue, 12 Sep 2023 15:10:41 +0800 Subject: [PATCH 4/5] feat: add aux_loss in topk router and modify doc --- colossalai/moe/routers.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py index 4ed7cb08214b..6fa89a416203 100644 --- a/colossalai/moe/routers.py +++ b/colossalai/moe/routers.py @@ -94,9 +94,6 @@ def set_z_loss(self, router_logits: torch.Tensor): Args: router_logits: [num_groups, tokens_per_group, num_experts] router logits. - - Returns: - Scalar router z-loss. """ assert self._z_loss is None if router_logits.dim() == 2: @@ -327,9 +324,9 @@ class TopKRouter(MoeRouter): processed by an expert, or that each expert receives at least one token. Attributes: - num_selected_experts: Maximum number of experts to which each token is - routed. Tokens may be routed to fewer experts if particular experts are - oversubscribed / reach capacity. + num_selected_experts: Maximum number of experts to which each token is + routed. Tokens may be routed to fewer experts if particular experts are + oversubscribed / reach capacity. """ def __init__(self, @@ -359,16 +356,15 @@ def forward(self, Returns: Dispatch and combine arrays for routing with masked matmuls. """ + # TODO: add parallel group num_groups, _, num_experts = router_probs.shape # Top-k router probability and corresponding expert indices for each token. # Shape: [num_groups, tokens_per_group, num_selected_experts]. expert_gate, expert_index = torch.topk(router_probs, self.k_value) - # TODO: - # 1. add router loss - # 2. add parallel group - # auxiliary_loss = _load_balancing_loss(router_probs, expert_index) + self.set_aux_loss(router_probs, expert_index, num_experts) + self.pop_router_loss() # Make num_selected_experts the leading axis to ensure that top-1 choices # have priority over top-2 choices, which have priority over top-3 choices, From f03af279fc0124c8300d761daa909e316f8a8197 Mon Sep 17 00:00:00 2001 From: CWHer Date: Tue, 12 Sep 2023 18:19:51 +0800 Subject: [PATCH 5/5] docs: add fn doc --- .../openmoe/model/modeling_openmoe.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index ec7e1e8941f7..cf9c5013cc29 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -1020,7 +1020,19 @@ def _calculate_router_loss(self): z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss) return aux_loss, z_loss - def _calculate_loss(self, logits, targets): + def _calculate_loss(self, + logits: torch.Tensor, + targets: torch.Tensor + ) -> torch.Tensor: + """Compute cross entropy and entropy for log probs and targets. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical targets [batch, length] int array. + + Returns: + Tuple of scalar loss. + """ if len(logits.shape) != len(targets.shape) + 1: raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' % (str(logits.shape), str(targets.shape))) @@ -1045,6 +1057,28 @@ def _calculate_loss(self, logits, targets): class ZLossCrossEntropy(torch.autograd.Function): + """Computes cross entropy loss with stable custom gradient. + + Computes a stabilized-gradient version of: + -jnp.sum(targets * nn.log_softmax(logits), axis=-1) + + If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2 + will be added to the cross entropy loss (z = softmax normalization constant). + The two uses of z_loss are: + 1. To keep the logits from drifting too far from zero, which can cause + unacceptable roundoff errors in bfloat16. + 2. To encourage the logits to be normalized log-probabilities. + + Args: + logits: [batch, length, num_classes] float array. + targets: categorical one-hot targets [batch, length, num_classes] float + array. + z_loss: coefficient for auxilliary z-loss loss term. + + Returns: + tuple with the total loss and the z_loss, both + float arrays with shape [batch, length]. + """ @staticmethod def forward(ctx, logits, targets, z_loss):