From c3cdda42513bfb6fb0ecd4fe43a755d69c62a812 Mon Sep 17 00:00:00 2001 From: Xuanlei Zhao Date: Tue, 10 Oct 2023 15:32:18 +0800 Subject: [PATCH] update mm --- colossalai/moe/experts.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index e05ea59b3d28..904549b643aa 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -110,19 +110,19 @@ def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) - x = x.reshape(e, -1, h) if self.gated: + x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)] + x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)] if HAS_TRITON and self.act_name == "swiglu": - x = LlamaActCombine.apply( - torch.bmm(x, self.wi_gate[param_slice]), - torch.bmm(x, self.wi_up[param_slice]), - ) + x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)] else: - x = self.act(torch.bmm(x, self.wi_gate[param_slice])) * torch.bmm(x, self.wi_up[param_slice]) + x = [self.act(x_gate[i]) * x_up[i] for i in range(e)] else: - x = torch.bmm(x, self.wi[param_slice]) - x = self.act(x) - x = self.drop(x) - x = torch.bmm(x, self.wo[param_slice]) + x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)] + x = [self.act(x[i]) for i in range(e)] + x = [self.drop(x[i]) for i in range(e)] + x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)] + x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0) x = x.reshape(inshape) x = x.transpose(0, 1).contiguous() x = MoeOutGradScaler.apply(x, self.ep_size)