diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py
index 81a7b21544e4..47dceeae8edb 100644
--- a/colossalai/moe/experts.py
+++ b/colossalai/moe/experts.py
@@ -116,19 +116,19 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor:
         x = x.reshape(e, -1, h)
 
         if self.gated:
+            x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
+            x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
             if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
-                x = LlamaActCombine.apply(
-                    torch.bmm(x, self.wi_gate[param_slice]),
-                    torch.bmm(x, self.wi_up[param_slice]),
-                )
+                x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)]
             else:
-                x = self.act(torch.bmm(x, self.wi_gate[param_slice])) * torch.bmm(x, self.wi_up[param_slice])
+                x = [self.act(x_gate[i]) * x_up[i] for i in range(e)]
         else:
-            x = torch.bmm(x, self.wi[param_slice])
-            x = self.act(x)
-        x = self.drop(x)
-        x = torch.bmm(x, self.wo[param_slice])
+            x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)]
+            x = [self.act(x[i]) for i in range(e)]
+        x = [self.drop(x[i]) for i in range(e)]
+        x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
 
+        x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
         x = x.reshape(inshape)
         x = x.transpose(0, 1).contiguous()
         x = MoeOutGradScaler.apply(x, self.ep_size)
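
For reference, a minimal standalone sketch (not part of the patch; all shapes and names below are hypothetical stand-ins) illustrating why the rewrite is numerically equivalent: looping torch.mm over the expert dimension and re-stacking with torch.cat produces the same result as the single torch.bmm it replaces, mirroring the list comprehensions in the diff.

import torch

# Hypothetical sizes: experts, tokens per expert, hidden dim, intermediate dim.
e, t, h, d = 4, 8, 16, 32
x = torch.randn(e, t, h)   # stand-in for x after x.reshape(e, -1, h)
wi = torch.randn(e, h, d)  # stand-in for a per-expert weight like self.wi[param_slice]

# Old path: one batched matmul over the expert dimension.
out_bmm = torch.bmm(x, wi)

# New path: one torch.mm per expert, then re-stack along dim 0,
# as the diff does with torch.cat over unsqueezed slices.
out_mm = [torch.mm(x[i], wi[i]) for i in range(e)]
out_mm = torch.cat([out_mm[i].unsqueeze(0) for i in range(e)], dim=0)

torch.testing.assert_close(out_bmm, out_mm)
print("bmm and per-expert mm agree:", tuple(out_mm.shape))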