79 changes: 47 additions & 32 deletions colossalai/moe/_operation.py
@@ -9,8 +9,6 @@
from colossalai.moe.manager import MOE_MANAGER

MOE_KERNEL = None
WORLD_HANDLE_ALLGATHER = None
WORLD_HANDLE_REDUCESCATTER = None


def load_moe():
@@ -28,14 +26,20 @@ def forward(
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tensor:
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], non-None only when overlap is True
"""
assert ctx is not None or not overlap

if ctx is not None:
ctx.comm_grp = group
ctx.overlap = overlap

comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.unsqueeze(0)
return inputs.unsqueeze(0), None

buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
@@ -45,19 +49,12 @@ def forward(
return outputs, None
else:
handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
if ctx is None and overlap:
global WORLD_HANDLE_ALLGATHER
WORLD_HANDLE_ALLGATHER = handle
return outputs, handle

@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
global WORLD_HANDLE_REDUCESCATTER
if WORLD_HANDLE_REDUCESCATTER is not None:
WORLD_HANDLE_REDUCESCATTER.wait()
WORLD_HANDLE_REDUCESCATTER = None
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
return (
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0],
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
None,
None,
)
@@ -71,14 +68,20 @@ def forward(
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tensor:
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], non-None only when overlap is True
"""
assert ctx is not None or not overlap

if ctx is not None:
ctx.comm_grp = group
ctx.overlap = overlap

comm_size = dist.get_world_size(group)
if comm_size == 1:
return inputs.squeeze(0)
return inputs.squeeze(0), None

if not inputs.is_contiguous():
inputs = inputs.contiguous()
@@ -91,19 +94,13 @@ def forward(
return outputs, None
else:
handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
if ctx is None and overlap:
global WORLD_HANDLE_REDUCESCATTER
WORLD_HANDLE_REDUCESCATTER = handle
return outputs, handle

@staticmethod
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
global WORLD_HANDLE_ALLGATHER
if WORLD_HANDLE_ALLGATHER is not None:
WORLD_HANDLE_ALLGATHER.wait()
WORLD_HANDLE_ALLGATHER = None
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
# TODO: support async backward
return (
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, ctx.overlap)[0],
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
None,
None,
)
@@ -115,20 +112,38 @@ class AllToAll(torch.autograd.Function):
"""

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
def forward(
ctx: Any,
inputs: Tensor,
group: Optional[ProcessGroup] = None,
overlap: bool = False,
) -> Tuple[Tensor, Any]:
"""
Returns:
outputs: Tensor
handle: Optional[Work], non-None only when overlap is True
"""
if ctx is not None:
ctx.comm_grp = group
if not inputs.is_contiguous():
inputs = inputs.contiguous()
if dist.get_world_size(group) == 1:
return inputs
return inputs, None
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=group)
return output
if not overlap:
dist.all_to_all_single(output, inputs, group=group)
return output, None
else:
handle = dist.all_to_all_single(output, inputs, group=group, async_op=True)
return output, handle

@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
return (
AllToAll.forward(None, grad_outputs[0], ctx.comm_grp)[0],
None,
None,
)
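
For reference, a minimal caller-side sketch (not part of this diff) of the contract the three collectives above now share: with overlap=True, apply returns (output, handle), the caller is responsible for waiting on the handle before reading the output buffer, and the handle is None when the group has a single rank. Process-group setup (e.g. torchrun plus dist.init_process_group) is assumed.

import torch
import torch.distributed as dist

from colossalai.moe._operation import AllToAll


def overlapped_all_to_all(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # Launch the exchange asynchronously; handle is a dist.Work object,
    # or None when the group has a single rank and the op short-circuits.
    out, handle = AllToAll.apply(x, group, True)
    # ... independent computation can run here while the communication is in flight ...
    if handle is not None:
        handle.wait()
    return out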


class MoeDispatch(torch.autograd.Function):
205 changes: 133 additions & 72 deletions colossalai/moe/layers.py
@@ -1,3 +1,4 @@
import dataclasses
import math
from typing import Any, Optional, Tuple

Expand Down Expand Up @@ -188,7 +189,10 @@ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
expert_out = self.experts(expert_in)
return expert_out

def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
def _ep_process(self,
dispatch_data: torch.Tensor,
overlap: bool = True
) -> torch.Tensor:
"""
Expert Parallel

@@ -198,93 +202,150 @@ def _ep_process(self, dispatch_data: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: (num_experts, capacity, hidden_size)
"""
expert_input = AllToAll.apply(dispatch_data, self.ep_group)
input_shape = expert_input.shape
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input)
expert_output = expert_output.reshape(input_shape)
expert_output = AllToAll.apply(expert_output, self.ep_group)
return expert_output

def _tp_process(self, dispatch_data: torch.Tensor, use_overlap: bool = False) -> torch.Tensor:
"""
TP with overlap.
if not overlap or dist.get_world_size(self.ep_group) == 1:
expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
expert_output = self.experts(expert_input)
expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
return expert_output

origin:
else:
@dataclasses.dataclass
class Capsule():
data: torch.Tensor
handle: Any = None

NUM_CHUNK = 2
NUM_STAGES = 4

assert dispatch_data.shape[1] % NUM_CHUNK == 0, \
"arbitrary chunk num is not supported yet"
chunk_size = dispatch_data.shape[1] // NUM_CHUNK
input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
dispatch_data = dispatch_data.reshape(*input_shape)
chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
output = torch.empty_like(dispatch_data)

offset = 0
_expert_in, expert_in, _expert_out, expert_out = None, None, None, None

for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[:, :, offset:offset + chunk_size, :] = expert_out.data
offset += chunk_size
expert_out = None

# all2all last output
if _expert_out is not None:
expert_out = Capsule(
*AllToAll.apply(_expert_out.data, self.ep_group, True),
)
_expert_out = None

# all2all next input
if 0 <= i < NUM_CHUNK:
_expert_in = Capsule(
*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True)
)

# compute
if expert_in is not None:
expert_in.handle.wait()
_expert_out = Capsule(
data=self.experts(expert_in.data),
handle=None
)
expert_in = None

if _expert_in is not None:
expert_in = _expert_in
_expert_in = None

return output
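
The staged loop above is easier to follow stripped of the expert specifics. Below is a simplified, self-contained sketch of the same pattern with hypothetical helper names (_Chunk, pipelined_all_to_all, compute), raw dist.all_to_all_single calls in place of AllToAll.apply, and the same NUM_CHUNK/NUM_STAGES bookkeeping; it assumes dispatch_data has the group's world size as its leading dimension and an already initialized process group.

import dataclasses
from typing import Any, Callable, Optional

import torch
import torch.distributed as dist


@dataclasses.dataclass
class _Chunk:
    # stand-in for Capsule: a buffer plus the Work handle guarding it
    data: torch.Tensor
    handle: Any = None


def pipelined_all_to_all(dispatch_data: torch.Tensor,
                         group: Optional[dist.ProcessGroup],
                         compute: Callable[[torch.Tensor], torch.Tensor],
                         num_chunk: int = 2,
                         num_stages: int = 4) -> torch.Tensor:
    # dispatch_data: (world_size, tokens, hidden); tokens must be divisible by num_chunk
    chunk_size = dispatch_data.shape[1] // num_chunk
    chunks = torch.split(dispatch_data, chunk_size, dim=1)
    output = torch.empty_like(dispatch_data)

    offset = 0
    _in = cur_in = _out = cur_out = None
    for i in range(num_chunk + num_stages - 1):
        # drain: wait for the oldest chunk's output exchange and store it
        if cur_out is not None:
            cur_out.handle.wait()
            output[:, offset:offset + chunk_size] = cur_out.data
            offset += chunk_size
            cur_out = None
        # launch the all-to-all returning the chunk computed last iteration
        if _out is not None:
            src = _out.data.contiguous()
            buf = torch.empty_like(src)
            cur_out = _Chunk(buf, dist.all_to_all_single(buf, src, group=group, async_op=True))
            _out = None
        # launch the all-to-all scattering the next input chunk
        if i < num_chunk:
            src = chunks[i].contiguous()
            buf = torch.empty_like(src)
            _in = _Chunk(buf, dist.all_to_all_single(buf, src, group=group, async_op=True))
        # compute on the chunk whose input exchange was launched one iteration ago
        if cur_in is not None:
            cur_in.handle.wait()
            _out, cur_in = _Chunk(compute(cur_in.data)), None
        if _in is not None:
            cur_in, _in = _in, None
    return output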

def _tp_process(self,
dispatch_data: torch.Tensor,
overlap: bool = True
) -> torch.Tensor:
"""
without overlap:
| C |
| A | | R |

overlap:
with overlap:
| C1 || C2 || C3 || C4 |
| A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 |

C is computation, A is all gather, R is reduce scatter.
where C is computation, A is all gather, R is reduce scatter.

Args:
dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)

Returns:
torch.Tensor: (num_experts, capacity, hidden_size)
"""
if use_overlap == False:
expert_in, _ = AllGather.apply(dispatch_data, self.ep_group)
if not overlap or dist.get_world_size(self.ep_group) == 1:
expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
expert_out = self.experts(expert_in)
expert_out, _ = ReduceScatter.apply(expert_out, self.ep_group)
expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
return expert_out

# TODO: there is accuracy problem in overlap
chunk_num = 4
chunk_size = dispatch_data.shape[0] // chunk_num
out = torch.empty_like(dispatch_data)
in_data = None
in_handle = None
out_data = None
out_handle = None

# backward compatibility for async op
torch.cuda.synchronize()

def get_chunk_slice(idx: int, gap: int) -> Tuple[slice]:
return (slice(idx * gap, (idx + 1) * gap),)

for i in range(chunk_num):
cur_chunk_slice = get_chunk_slice(i, chunk_size)

# if first, all gather
if i == 0:
d = dispatch_data[cur_chunk_slice].contiguous()
expert_in, _ = AllGather.apply(d, self.ep_group)
else:
expert_in = in_data

# async communication while compute
if i != 0:
# reduce scatter last out
out_data, out_handle = ReduceScatter.apply(out_data, self.ep_group, True)
if i != chunk_num - 1:
# all gather next in
next_d = dispatch_data[get_chunk_slice(i + 1, chunk_size)].contiguous()
in_data, in_handle = AllGather.apply(next_d, self.ep_group, True)

# compute
expert_out = self.experts(expert_in, cur_chunk_slice)

# sync handle
if i != 0:
out_handle.wait()
out[get_chunk_slice(i - 1, chunk_size)] = out_data
if i != chunk_num - 1:
in_handle.wait()
out_data = expert_out

# store out for last loop
if i == chunk_num - 1:
out_data, _ = ReduceScatter.apply(out_data, self.ep_group)
out[cur_chunk_slice] = out_data

# sync for async op
torch.cuda.synchronize()
return out
else:
@dataclasses.dataclass
class Capsule():
data: torch.Tensor
handle: Any
indices: Tuple

NUM_CHUNK = 2
NUM_STAGES = 4

assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
"arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
chunk_size = dispatch_data.shape[0] // NUM_CHUNK
chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
output = torch.empty_like(dispatch_data)

def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
return (slice(idx * chunk_size, (idx + 1) * chunk_size), )

_expert_in, expert_in, _expert_out, expert_out = None, None, None, None

for i in range(NUM_CHUNK + NUM_STAGES - 1):
if expert_out is not None:
expert_out.handle.wait()
output[expert_out.indices] = expert_out.data
expert_out = None

# reduce scatter last output
if _expert_out is not None:
expert_out = Capsule(
*ReduceScatter.apply(_expert_out.data, self.ep_group, True),
indices=_expert_out.indices
)
_expert_out = None

# all gather next input
if 0 <= i < NUM_CHUNK:
_expert_in = Capsule(
*AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
indices=get_chunk_slice(i, chunk_size)
)

# compute
if expert_in is not None:
expert_in.handle.wait()
_expert_out = Capsule(
self.experts(expert_in.data, expert_in.indices),
handle=None, indices=expert_in.indices
)
expert_in = None

if _expert_in is not None:
expert_in = _expert_in
_expert_in = None

return output
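
As a sanity check on the schedule sketched in the docstring above, here is a tiny dry run (no distributed setup required; chunk indices stand in for the Capsules) that prints which operations each of the NUM_CHUNK + NUM_STAGES - 1 loop iterations issues, with A_i = all-gather, C_i = expert compute and R_i = reduce-scatter of chunk i. The constants follow the code in this diff; the snippet itself is only an illustration.

NUM_CHUNK, NUM_STAGES = 2, 4

_in = cur_in = _out = cur_out = None  # chunk indices standing in for the Capsules
for i in range(NUM_CHUNK + NUM_STAGES - 1):
    ops = []
    if cur_out is not None:                # wait on the last reduce-scatter, write output
        ops.append(f"wait R{cur_out}")
        cur_out = None
    if _out is not None:                   # launch reduce-scatter of the last compute
        ops.append(f"launch R{_out}")
        cur_out, _out = _out, None
    if i < NUM_CHUNK:                      # launch all-gather of the next input chunk
        ops.append(f"launch A{i}")
        _in = i
    if cur_in is not None:                 # wait on the all-gather, run the experts
        ops.append(f"wait A{cur_in}, run C{cur_in}")
        _out, cur_in = cur_in, None
    if _in is not None:
        cur_in, _in = _in, None
    print(f"iter {i}: " + ", ".join(ops))

# iter 0: launch A0
# iter 1: launch A1, wait A0, run C0
# iter 2: launch R0, wait A1, run C1
# iter 3: wait R0, launch R1
# iter 4: wait R1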


def apply_load_balance(model: nn.Module, optim: Any) -> None: