colossalai/moe/_operation.py: 64 changes (54 additions & 10 deletions)
@@ -8,6 +8,8 @@
from colossalai.moe.manager import MOE_MANAGER

MOE_KERNEL = None
WORLD_HANDLE_ALLGATHER = None
WORLD_HANDLE_REDUCESCATTER = None


def load_moe():
@@ -20,9 +22,15 @@ def load_moe():
class AllGather(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
ctx.overlap = overlap

comm_size = dist.get_world_size(group)
if comm_size == 1:
@@ -31,20 +39,41 @@ def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
dist.all_gather(buffer_list, inputs, group=group)
return outputs
if not overlap:
dist.all_gather(buffer_list, inputs, group=group)
return outputs, None
else:
handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
if ctx is None and overlap:
global WORLD_HANDLE_ALLGATHER
WORLD_HANDLE_ALLGATHER = handle
return outputs, handle

@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
global WORLD_HANDLE_REDUCESCATTER
if WORLD_HANDLE_REDUCESCATTER is not None:
WORLD_HANDLE_REDUCESCATTER.wait()
WORLD_HANDLE_REDUCESCATTER = None
return (
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0],
None,
None,
)


class ReduceScatter(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tensor:
if ctx is not None:
ctx.comm_grp = group
ctx.overlap = overlap

comm_size = dist.get_world_size(group)
if comm_size == 1:
@@ -56,12 +85,27 @@ def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs
if not overlap:
dist.reduce_scatter(outputs, buffer_list, group=group)
return outputs, None
else:
handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
if ctx is None and overlap:
global WORLD_HANDLE_REDUCESCATTER
WORLD_HANDLE_REDUCESCATTER = handle
return outputs, handle

@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
global WORLD_HANDLE_ALLGATHER
if WORLD_HANDLE_ALLGATHER is not None:
WORLD_HANDLE_ALLGATHER.wait()
WORLD_HANDLE_ALLGATHER = None
return (
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0],
None,
None,
)


class AllToAll(torch.autograd.Function):
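
Editor's note on the overlap path above (a usage sketch, not part of the PR): with overlap=True, AllGather.forward and ReduceScatter.forward launch the collective with async_op=True and return the communication handle together with the output buffer, so a caller can run independent work before waiting on the handle. In the sketch below, only the (outputs, handle) return contract of AllGather comes from the diff; the function name overlapped_gather and the tensor/group names are placeholders.

import torch
import torch.distributed as dist

from colossalai.moe._operation import AllGather


def overlapped_gather(hidden_states: torch.Tensor, ep_group: dist.ProcessGroup) -> torch.Tensor:
    # Launch the all-gather asynchronously; `outputs` is only safe to read
    # after the returned work handle has completed.
    outputs, handle = AllGather.apply(hidden_states, ep_group, True)

    # ... independent computation can run here while the collective is in flight ...

    if handle is not None:  # the handle may be None, e.g. on the synchronous path
        handle.wait()
    return outputs
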
colossalai/moe/experts.py: 80 changes (52 additions & 28 deletions)
@@ -1,6 +1,6 @@
import math
from contextlib import nullcontext
from typing import Callable, Optional
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
@@ -52,8 +52,9 @@ def __init__(
num_experts, use_tp=True if expert_parallel == "TP" else False)
# get settings for different parallel
if expert_parallel == "TP":
assert intermediate_size % MOE_MANAGER.max_ep_size == 0, \
"intermediate_size should be divide by maximum expert parallel size"
assert (
intermediate_size %
MOE_MANAGER.max_ep_size == 0), "intermediate_size should be divisible by the maximum expert parallel size"
intermediate_size = intermediate_size // MOE_MANAGER.max_ep_size
num_experts = self.num_total_experts
else:
@@ -77,11 +78,11 @@ def __init__(
seed_ctx = nullcontext()
with seed_ctx:
if gated:
nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
torch.nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
torch.nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
else:
nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))
torch.nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
torch.nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))

self.act_name = activation
self.act = get_activation(activation)
@@ -91,7 +92,7 @@ def __init__(
for param in self.parameters():
set_moe_tensor_info(param, self.moe_info)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, param_slice: Tuple[slice] = (slice(None),)) -> torch.Tensor:
"""
Args:
x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
@@ -110,14 +111,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

if self.gated:
if HAS_TRITON and self.act_name == "swiglu":
x = LlamaActCombine.apply(torch.bmm(x, self.wi_gate), torch.bmm(x, self.wi_up))
x = LlamaActCombine.apply(
torch.bmm(x, self.wi_gate[param_slice]),
torch.bmm(x, self.wi_up[param_slice]),
)
else:
x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up)
x = self.act(torch.bmm(x, self.wi_gate[param_slice])) * torch.bmm(x, self.wi_up[param_slice])
else:
x = torch.bmm(x, self.wi)
x = torch.bmm(x, self.wi[param_slice])
x = self.act(x)
x = self.drop(x)
x = torch.bmm(x, self.wo)
x = torch.bmm(x, self.wo[param_slice])

x = x.reshape(inshape)
x = x.transpose(0, 1).contiguous()
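
Editor's note on the new param_slice argument (an illustrative sketch, not from the PR): the tuple of slice objects is indexed directly into the batched weight tensors (wi, wi_gate, wi_up, wo), so the default (slice(None),) uses every expert's weights, while a narrower slice restricts the batched matmul to a subset of experts; the input's leading (expert) dimension has to match the slice. The toy shapes below are placeholders.

import torch

num_experts, hidden_size, intermediate_size, capacity = 4, 8, 16, 2
wi = torch.randn(num_experts, hidden_size, intermediate_size)   # stands in for self.wi
x = torch.randn(num_experts, capacity, hidden_size)             # per-expert token batch

full = torch.bmm(x, wi[(slice(None),)])          # default param_slice: all experts
subset = torch.bmm(x[:2], wi[(slice(0, 2),)])    # only experts 0 and 1

assert full.shape == (num_experts, capacity, intermediate_size)
assert subset.shape == (2, capacity, intermediate_size)
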
@@ -130,14 +134,24 @@ class EPMLPExperts(BaseMLPExperts):
Use expert parallelism to split each expert evenly, which can deploy experts in
"""

def __init__(self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
activation=None,
drop_rate: float = 0,
gated: bool = False):
super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate, gated)
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
activation=None,
drop_rate: float = 0,
gated: bool = False,
):
super().__init__(
num_experts,
hidden_size,
intermediate_size,
"EP",
activation,
drop_rate,
gated,
)


class TPMLPExperts(BaseMLPExperts):
@@ -146,14 +160,24 @@ class TPMLPExperts(BaseMLPExperts):
maximum expert parallel size can't be divided by the number of experts.
"""

def __init__(self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
activation: str = None,
drop_rate: float = 0,
gated: bool = False):
super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate, gated)
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
activation: str = None,
drop_rate: float = 0,
gated: bool = False,
):
super().__init__(
num_experts,
hidden_size,
intermediate_size,
"TP",
activation,
drop_rate,
gated,
)


def get_expert_class(name: str) -> BaseMLPExperts:
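
Editor's sketch of how the reworked constructors above might be called (all hyperparameter values are placeholders, and a configured MOE_MANAGER plus an initialized distributed environment is assumed; get_expert_class is expected to resolve a name to one of these classes, but its body is not shown in this diff, so the classes are used directly):

from colossalai.moe.experts import EPMLPExperts, TPMLPExperts

# Expert-parallel experts: whole experts are distributed across ranks.
ep_experts = EPMLPExperts(
    num_experts=8,
    hidden_size=1024,
    intermediate_size=4096,
    activation="swiglu",   # "swiglu" matches the gated Triton path in forward()
    drop_rate=0.0,
    gated=True,
)

# Tensor-parallel experts: intermediate_size must be divisible by the
# maximum expert parallel size, since it is sharded across ranks.
tp_experts = TPMLPExperts(
    num_experts=8,
    hidden_size=1024,
    intermediate_size=4096,
    activation="swiglu",
    drop_rate=0.0,
    gated=True,
)
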