69 changes: 64 additions & 5 deletions colossalai/moe/_operation.py
@@ -6,8 +6,6 @@
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
 
-from colossalai.moe.manager import MOE_MANAGER
-
 MOE_KERNEL = None
 
 
@@ -64,7 +62,7 @@ class ReduceScatter(torch.autograd.Function):
     def forward(
         ctx: Any,
         inputs: Tensor,
-        group: Optional[ProcessGroup] = None,
+        group: ProcessGroup,
         overlap: bool = False,
     ) -> Tuple[Tensor, Any]:
         """
@@ -113,14 +111,16 @@ class AllToAll(torch.autograd.Function):
     def forward(
         ctx: Any,
         inputs: Tensor,
-        group: Optional[ProcessGroup] = None,
+        group: ProcessGroup,
         overlap: bool = False,
     ) -> Tuple[Tensor, Any]:
         """
         Returns:
             outputs: Tensor
             handle: Optional[Work], if overlap is True
         """
+        assert ctx is not None or not overlap
+
         if ctx is not None:
             ctx.comm_grp = group
         if not inputs.is_contiguous():
@@ -138,8 +138,67 @@ def forward(
     @staticmethod
     def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
         return (
-            AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0],
+            AllToAll.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
             None,
             None,
         )
 
+
+class HierarchicalAllToAll(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        inputs: Tensor,
+        groups: Tuple[ProcessGroup],
+    ) -> Tensor:
+        """
+        Returns:
+            outputs: Tensor
+        """
+        # TODO: we can reduce comm volume by removing empty capacity
+        if ctx is not None:
+            ctx.comm_grps = groups
+        intra_node_group, inter_node_group = groups
+
+        local_world_size = dist.get_world_size(intra_node_group)
+        num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1
+        world_size = local_world_size * num_group
+        src_rank = dist.get_process_group_ranks(intra_node_group)[0]
+        outputs = torch.empty_like(inputs)
+
+        if dist.get_rank() == src_rank:
+            # intra-node gather
+            intra_output = [torch.empty_like(inputs) for _ in range(local_world_size)]
+            dist.gather(inputs, intra_output, dst=src_rank, group=intra_node_group)
+
+            intra_output = [v.chunk(world_size, dim=0) for v in intra_output]
+            intra_output = torch.cat(sum(zip(*intra_output), ()))
+
+            # inter-node all-to-all
+            if inter_node_group is not None:
+                inter_output = torch.empty_like(intra_output)
+                dist.all_to_all_single(inter_output, intra_output, group=inter_node_group)
+
+                # layout transform
+                inter_output = inter_output.chunk(num_group, dim=0)
+                inter_output = [v.chunk(local_world_size, dim=0) for v in inter_output]
+                intra_output = torch.cat(sum(zip(*inter_output), ()))
+
+            # intra-node scatter
+            intra_output = list(intra_output.chunk(local_world_size, dim=0))
+            dist.scatter(outputs, intra_output, src=src_rank, group=intra_node_group)
+
+        else:
+            dist.gather(inputs, dst=src_rank, group=intra_node_group)
+            dist.scatter(outputs, src=src_rank, group=intra_node_group)
+
+        return outputs
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
+        return (
+            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps),
+            None,
+        )

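Reviewer note on `HierarchicalAllToAll`: instead of one flat all-to-all across all ranks, it performs an intra-node gather onto each node's leader rank, a single inter-node all-to-all among leaders, and an intra-node scatter back, so only leader ranks generate network traffic. The least obvious step is the `torch.cat(sum(zip(*chunks), ()))` layout transform, which regroups rows by destination rank. Below is a minimal single-process sketch of just that transform, with made-up toy sizes and no process groups involved:

```python
import torch

# Assumed toy topology: 2 ranks per node, 2 nodes -> 4 global ranks.
local_world_size = 2
world_size = 4

# Pretend the node leader gathered one tensor per local rank; each tensor
# holds one row destined for each of the 4 global ranks.
gathered = [torch.arange(world_size) + 10 * r for r in range(local_world_size)]
# gathered[0] = [0, 1, 2, 3], gathered[1] = [10, 11, 12, 13]

# Chunk each gathered tensor by destination, then interleave the chunks so
# rows bound for the same destination rank become contiguous.
chunks = [t.chunk(world_size, dim=0) for t in gathered]
interleaved = torch.cat(sum(zip(*chunks), ()))
print(interleaved)  # tensor([ 0, 10,  1, 11,  2, 12,  3, 13])
```

After this regrouping, one `dist.all_to_all_single` over the inter-node group moves each node's traffic as a contiguous block, which is the point of the hierarchy.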
82 changes: 55 additions & 27 deletions colossalai/moe/layers.py
@@ -7,12 +7,12 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
+from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
 from colossalai.moe.experts import MLPExperts
 from colossalai.moe.load_balance import LoadBalancer
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
-from colossalai.moe.utils import get_noise_generator
+from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
 from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
 
 
@@ -51,19 +51,20 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         router_top_k: int = 1,
-        router_capacity_factor_train: Optional[float] = 1.25,
-        router_capacity_factor_eval: Optional[float] = 2.0,
-        router_min_capacity: Optional[int] = 4,
+        router_capacity_factor_train: float = 1.25,
+        router_capacity_factor_eval: float = 2.0,
+        router_min_capacity: int = 4,
         router_noisy_policy: Optional[str] = None,
-        router_drop_tks: Optional[bool] = True,
+        router_drop_tks: bool = True,
         mlp_activation: Optional[str] = None,
-        mlp_gated: Optional[bool] = False,
-        enable_load_balance: Optional[bool] = False,
-        load_balance_tolerance: Optional[float] = 0.1,
-        load_balance_beam_width: Optional[int] = 8,
-        load_balance_group_swap_factor: Optional[float] = 0.4,
-        enable_kernel: Optional[bool] = False,
-        enable_comm_overlap: Optional[bool] = False,
+        mlp_gated: bool = False,
+        enable_load_balance: bool = False,
+        load_balance_tolerance: float = 0.1,
+        load_balance_beam_width: int = 8,
+        load_balance_group_swap_factor: float = 0.4,
+        enable_kernel: bool = False,
+        enable_comm_overlap: bool = False,
+        enable_hierarchical_comm: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -104,6 +105,8 @@ def __init__(
         if self.expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
+            self.ep_hierarchical_group = create_ep_hierarchical_group(
+                self.ep_group) if enable_hierarchical_comm else None
             self.dp_group = get_dp_group(self.experts)
         else:
             self.ep_group = None
@@ -132,7 +135,7 @@ def __init__(
     def reset_parameters(self):
         torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))
 
-    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         Args:
             inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)
@@ -158,7 +161,8 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
             self.load_balancer.update_load(expert_load)
 
         # the result from the router
-        route_result_list = self.router(inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
+        used_capacity, *route_result_list = self.router(
+            inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
 
         # dispatch_data: (num_experts, capacity, hidden_size)
         if self.enable_kernel:
@@ -170,9 +174,17 @@ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
         # expert_output: (num_groups, num_experts, capacity, hidden_size)
         if self.expert_parallel == "EP":
-            expert_output = self._ep_process(dispatch_data, overlap=self.enable_comm_overlap)
+            expert_output = self._ep_process(
+                dispatch_data,
+                used_capacity,
+                overlap=self.enable_comm_overlap
+            )
         elif self.expert_parallel == "TP":
-            expert_output = self._tp_process(dispatch_data, overlap=self.enable_comm_overlap)
+            expert_output = self._tp_process(
+                dispatch_data,
+                used_capacity,
+                overlap=self.enable_comm_overlap
+            )
         elif self.expert_parallel is None:
             expert_output = self._local_process(dispatch_data)
         else:
@@ -196,7 +208,12 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
         expert_out = self.experts(expert_in)
         return expert_out
 
-    def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
+    def _ep_process(
+        self,
+        dispatch_data: torch.Tensor,
+        used_capacity: torch.Tensor,
+        overlap: bool = False
+    ) -> torch.Tensor:
         """
         Expert Parallel
 
@@ -207,12 +224,18 @@ def _ep_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
             torch.Tensor: (num_experts, capacity, hidden_size)
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
-            expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
-            expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
-            expert_output = self.experts(expert_input)
-            expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
-            return expert_output
-
+            if self.ep_hierarchical_group is not None:
+                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group)
+                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
+                expert_output = self.experts(expert_input)
+                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group)
+                return expert_output
+            else:
+                expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
+                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
+                expert_output = self.experts(expert_input)
+                expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
+                return expert_output
         else:
 
             @dataclasses.dataclass
@@ -261,7 +284,12 @@ class Capsule:
 
         return output
 
-    def _tp_process(self, dispatch_data: torch.Tensor, overlap: bool = False) -> torch.Tensor:
+    def _tp_process(
+        self,
+        dispatch_data: torch.Tensor,
+        used_capacity: torch.Tensor,
+        overlap: bool = False
+    ) -> torch.Tensor:
         """
         without overlap:
             | C |
@@ -295,8 +323,8 @@ class Capsule:
         NUM_CHUNK = 4
         NUM_STAGES = 4
 
-        assert (dispatch_data.shape[0] % NUM_CHUNK == 0
-                ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+        assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
+            "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
         chunk_size = dispatch_data.shape[0] // NUM_CHUNK
        chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
         output = torch.empty_like(dispatch_data)
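For context, a hedged usage sketch of the new flag. The enclosing class in this file is assumed to be `SparseMLP`, and any constructor arguments not visible in this hunk (e.g. `num_experts`) are assumptions, not taken from the PR:

```python
import torch
from colossalai.moe.layers import SparseMLP  # assumed location, per this diff

# Hypothetical construction: with expert parallelism active,
# enable_hierarchical_comm=True makes __init__ build ep_hierarchical_group via
# create_ep_hierarchical_group(self.ep_group), and _ep_process then routes
# tokens through HierarchicalAllToAll instead of the flat AllToAll.
moe_layer = SparseMLP(
    num_experts=8,
    hidden_size=1024,
    intermediate_size=4096,
    router_top_k=2,
    enable_hierarchical_comm=True,
)

x = torch.randn(2, 16, 1024)    # (batch_size, seq_len, hidden_size)
y = moe_layer(x)                # forward is now annotated -> torch.Tensor
```

Note the signature change: `forward` now returns a single `Tensor`, so any caller that unpacked a tuple from this layer needs updating.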
48 changes: 28 additions & 20 deletions colossalai/moe/routers.py
@@ -138,9 +138,10 @@ def __init__(self,
         self.select_policy = select_policy
         assert select_policy in {"first", "random"}
         if select_policy == "random":
-            self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
-                                                               high=torch.tensor(1.0,
-                                                                                 device=get_current_device())).rsample
+            self.uniform = torch.distributions.uniform.Uniform(
+                low=torch.tensor(0.0, device=get_current_device()),
+                high=torch.tensor(1.0, device=get_current_device())
+            ).rsample
 
     def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
         """
@@ -165,7 +166,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
         top1_idx = torch.argmax(inputs, dim=-1)
         mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
 
-        # caculate router loss
+        # calculate router loss
         self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
         self.set_z_loss(inputs)
         self.pop_router_loss()
@@ -187,18 +188,19 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
             raise NotImplementedError("Not support such select policy yet.")
 
         ranks = torch.sum(mask * ranks, dim=-1)
+        used_capacity = mask.sum(dim=0)
 
         if use_kernel:
             mask = torch.sum(mask, dim=-1)
             mask = torch.stack([mask], dim=0).to(torch.int32)
             dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
-            return probs, mask, dest_idx, num_experts * capacity
+            return used_capacity, probs, mask, dest_idx, num_experts * capacity
         else:
             ranks = F.one_hot(ranks, num_classes=capacity)
             weight = mask * probs.type_as(inputs)
             combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
             sec_mask = combine_weights.bool()
-            return combine_weights, sec_mask
+            return used_capacity, combine_weights, sec_mask
 
 
 class Top2Router(MoeRouter):
@@ -256,7 +258,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
         cmask = (mask1 + mask2)  # loss: [s, e]
         cmask = cmask.float() / 2.0  # div 2 to normalize it to 1
 
-        # caculate loss
+        # calculate loss
         expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
         self.set_aux_loss(probs, expert_indices, num_experts)
         self.set_z_loss(inputs)
@@ -273,6 +275,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
 
         mask1 *= torch.lt(rank1, capacity)
         mask2 *= torch.lt(rank2, capacity)
+        used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)
 
         rank1 = torch.sum(mask1 * rank1, dim=-1)
         rank2 = torch.sum(mask2 * rank2, dim=-1)
@@ -284,18 +287,23 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
             mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
             dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
 
-            return probs, mask, dest_idx, num_experts * capacity
+            return used_capacity, probs, mask, dest_idx, num_experts * capacity
         else:
-            # >>> original code
-            # weight1 = mask1 * probs.type_as(inputs)
-            # weight2 = mask2 * probs.type_as(inputs)
-            # rank1_sc = F.one_hot(rank1, num_classes=capacity)
-            # rank2_sc = F.one_hot(rank2, num_classes=capacity)
-
-            # cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
-            # cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
-            # cb_weight = cb_weight1 + cb_weight2
-            # sec_mask = cb_weight.bool()
+            """
+            The following code is equivalent to:
+
+            ```
+            weight1 = mask1 * probs.type_as(inputs)
+            weight2 = mask2 * probs.type_as(inputs)
+            rank1_sc = F.one_hot(rank1, num_classes=capacity)
+            rank2_sc = F.one_hot(rank2, num_classes=capacity)
+
+            cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
+            cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
+            cb_weight = cb_weight1 + cb_weight2
+            sec_mask = cb_weight.bool()
+            ```
+            """
 
             weight1 = mask1 * probs.type_as(inputs)
             weight2 = mask2 * probs.type_as(inputs)
@@ -308,7 +316,7 @@ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
             sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
             sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
 
-            return cb_weight, sec_mask
+            return used_capacity, cb_weight, sec_mask
 
 
 class TopKRouter(MoeRouter):
@@ -352,7 +360,7 @@ def forward(
         Returns:
             Dispatch and combine arrays for routing with masked matmuls.
         """
-        # TODO: add parallel group
+        # TODO: FIXME: add parallel group
         num_groups, _, num_experts = router_probs.shape
 
         # Top-k router probability and corresponding expert indices for each token.
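Summary of the router changes: both routers now return `used_capacity` as the first element of the result tuple, counting how many capacity slots each expert actually consumed after overflow tokens are dropped. A standalone toy illustration of the top-1 bookkeeping, with made-up numbers:

```python
import torch
import torch.nn.functional as F

num_experts, capacity = 3, 2
top1_idx = torch.tensor([0, 2, 2, 1])  # 4 tokens routed across 3 experts
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

# Position of each token in its expert's queue (exclusive cumsum),
# mirroring how the router assigns capacity slots before truncation.
ranks = torch.cumsum(mask, dim=0) * mask - mask
mask = mask * torch.lt(ranks, capacity)  # drop tokens past each expert's capacity

used_capacity = mask.sum(dim=0)
print(used_capacity)  # tensor([1, 1, 2])
```

For top-2 routing the same idea is summed over both assignment masks, exactly as in the diff: `used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)`.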