2 changes: 2 additions & 0 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -214,6 +214,7 @@ def __init__(
moe_dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
fp8_communication: bool = False,
) -> None:
if overlap_communication or zero_stage == 2:
overlap_communication = False
@@ -341,6 +342,7 @@ def __init__(
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
fp8_communication=fp8_communication,
)
self.amp_config = dict(
initial_scale=initial_scale,
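For context, a minimal sketch of how the new `fp8_communication` flag could be switched on at the plugin level. Everything except the flag itself (the Booster setup and the other constructor arguments) is an illustrative assumption, not part of this diff.

```python
# Hedged sketch: enabling FP8 communication when constructing the MoE plugin.
# All arguments other than fp8_communication are illustrative assumptions.
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    fp8_communication=True,  # new flag threaded through to the shardformer layers below
)
booster = Booster(plugin=plugin)
```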
23 changes: 18 additions & 5 deletions colossalai/moe/_operation.py
@@ -6,6 +6,8 @@
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from colossalai.quantization.fp8 import all_to_all_single_fp8

MOE_KERNEL = None


@@ -380,6 +382,7 @@ def _all_to_all(
output_split_sizes: Optional[List[int]] = None,
group=None,
async_op: bool = False,
fp8_communication: bool = False,
):
"""
Returns:
@@ -392,9 +395,14 @@
outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
inputs = inputs.contiguous()
outputs = outputs.contiguous()
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
if fp8_communication:
handle = all_to_all_single_fp8(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False
)
else:
handle = dist.all_to_all_single(
outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
)
return outputs, handle


@@ -407,6 +415,7 @@ def forward(
output_split_sizes=None,
group=None,
overlap: bool = False,
fp8_communication: bool = False,
):
"""
Returns:
@@ -416,7 +425,9 @@
ctx.input_split_sizes = input_split_sizes
ctx.output_split_sizes = output_split_sizes
ctx.group = group
return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
return _all_to_all(
inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
)

@staticmethod
def backward(ctx: Any, *grad_outputs):
@@ -426,6 +437,7 @@ def backward(ctx: Any, *grad_outputs):
None,
None,
None,
None,
)


@@ -435,8 +447,9 @@ def all_to_all_uneven(
output_split_sizes: Optional[List[int]] = None,
group=None,
overlap: bool = False,
fp8_communication: bool = False,
):
assert (
inputs.requires_grad
), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
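A hedged usage sketch for the updated `all_to_all_uneven` helper: with `fp8_communication=True` the exchange is routed through `all_to_all_single_fp8` instead of `dist.all_to_all_single`. The split sizes and process-group setup below are illustrative assumptions and presuppose an already-initialized two-rank job.

```python
# Hedged sketch: uneven token exchange with the payload cast to FP8 on the wire.
# Assumes torch.distributed is initialized with 2 ranks and that the peer rank's
# split sizes match; the numbers here are illustrative only.
import torch
import torch.distributed as dist
from colossalai.moe._operation import all_to_all_uneven

inputs = torch.randn(6, 16, device="cuda", requires_grad=True)  # must require grad (see assert above)
input_split_sizes = [2, 4]    # rows sent to rank 0 and rank 1
output_split_sizes = [3, 3]   # rows expected from rank 0 and rank 1

outputs, _ = all_to_all_uneven(
    inputs,
    input_split_sizes,
    output_split_sizes,
    group=dist.group.WORLD,
    fp8_communication=True,   # route the exchange through all_to_all_single_fp8
)
```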
23 changes: 13 additions & 10 deletions colossalai/quantization/fp8.py
@@ -27,16 +27,19 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max

if per_channel_scale:
per_channel_max = inp.abs().max(dim=-1).values.float()
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max[:, None]
scale_inv = per_channel_max / fp8_max
if inp.numel() == 0:
return inp.to(fp8_type), torch.tensor([1.0], device=inp.device)
else:
per_tensor_max = inp.abs().max().float()
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
scale_inv = 1.0 / scale
if per_channel_scale:
per_channel_max = inp.abs().max(dim=-1).values.float()
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max[:, None]
scale_inv = per_channel_max / fp8_max
else:
per_tensor_max = inp.abs().max().float()
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
scale_inv = 1.0 / scale

ret = (scale * inp.float()).to(fp8_type)
return ret, scale_inv
@@ -113,7 +116,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
for i in range(world_size):
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
out = torch.cat(tensor_list, dim=0)
tensor.copy_(out[:input_size].view(input_shape).to(input_type))

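To make the per-tensor path above concrete, here is a self-contained sketch of the cast-and-restore arithmetic that `cast_to_fp8` implements: pick a scale from the tensor's absolute maximum, quantize to e4m3, and recover values with the inverse scale. It assumes a PyTorch build with `torch.float8_e4m3fn` support and is not the library function itself.

```python
# Hedged sketch of the per-tensor FP8 scaling used by cast_to_fp8 above.
import torch

def cast_and_restore(x: torch.Tensor) -> torch.Tensor:
    fp8_type = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_type).max
    amax = x.abs().max().float()
    amax = torch.where(amax > 0, amax, torch.ones_like(amax))  # avoid division by zero
    scale = fp8_max / amax            # quantization scale
    scale_inv = 1.0 / scale           # kept for dequantization, as in the library code
    x_fp8 = (scale * x.float()).to(fp8_type)
    return x_fp8.float() * scale_inv  # dequantize back to float32

x = torch.randn(4, 8)
print((cast_and_restore(x) - x).abs().max())  # small quantization error remains
```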
4 changes: 3 additions & 1 deletion colossalai/shardformer/layer/embedding.py
@@ -274,6 +274,7 @@ def __init__(
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
make_vocab_size_divisible_by: int = 64,
fp8_communication: bool = False,
*args,
**kwargs,
):
@@ -282,6 +283,7 @@
self.embed_args = args
self.embed_kwargs = kwargs
self.process_group = process_group
self.fp8_communication = fp8_communication

tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)
@@ -390,5 +392,5 @@ def forward(self, input_: Tensor) -> Tensor:
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_forward(embedding_output, self.process_group)
output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication)
return output
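A hedged sketch of how the flag could reach this layer when wrapping a plain embedding. The `VocabParallelEmbedding1D` name and the `from_native_module` call pattern are assumptions about the shardformer layer this hunk edits; verify them against the actual module.

```python
# Hedged sketch: wrapping an nn.Embedding so its forward all-reduce can use FP8.
# Class name and call pattern are assumptions; fp8_communication is the flag added above.
import torch.nn as nn
import torch.distributed as dist
from colossalai.shardformer.layer import VocabParallelEmbedding1D

embed = nn.Embedding(32000, 4096)
parallel_embed = VocabParallelEmbedding1D.from_native_module(
    embed,
    process_group=dist.group.WORLD,
    fp8_communication=True,  # forwarded to reduce_forward in the forward pass
)
```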
69 changes: 53 additions & 16 deletions colossalai/shardformer/modeling/deepseek.py
@@ -24,6 +24,7 @@
all_to_all_uneven,
)
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization.fp8 import all_reduce_fp8
from colossalai.shardformer.layer._operation import (
all_to_all_comm,
gather_forward_split_backward,
@@ -61,7 +62,13 @@ class EPDeepseekMoE(nn.Module):
def __init__(self):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup):
def setup_process_groups(
self,
tp_group: ProcessGroup,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
fp8_communication: bool = False,
):
assert tp_group is not None
assert moe_dp_group is not None
assert ep_group is not None
@@ -70,6 +77,7 @@ def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGrou
self.ep_rank = dist.get_rank(ep_group)
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0
self.fp8_communication = fp8_communication

self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
@@ -86,9 +94,15 @@ def setup_process_groups(self, tp_group: ProcessGroup, moe_dp_group: ProcessGrou
self.tp_group = tp_group
if self.tp_group.size() > 1:
for expert in held_experts:
expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.tp_group)
expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.tp_group)
expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.tp_group)
expert.gate_proj = Linear1D_Col.from_native_module(
expert.gate_proj, self.tp_group, fp8_communication=self.fp8_communication
)
expert.up_proj = Linear1D_Col.from_native_module(
expert.up_proj, self.tp_group, fp8_communication=self.fp8_communication
)
expert.down_proj = Linear1D_Row.from_native_module(
expert.down_proj, self.tp_group, fp8_communication=self.fp8_communication
)

for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)
@@ -106,7 +120,8 @@ def from_native_module(
if module.__class__.__name__ == "DeepseekMLP":
return module
module.__class__ = EPDeepseekMoE
module.setup_process_groups(tp_group, moe_dp_group, ep_group)
fp8_communication = kwargs.get("fp8_communication", False)
module.setup_process_groups(tp_group, moe_dp_group, ep_group, fp8_communication=fp8_communication)
return module

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -137,11 +152,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for i in range(1, self.ep_size):
activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
activate_experts = (activate_experts > 0).float()
dist.all_reduce(activate_experts, group=self.moe_dp_group)

if self.fp8_communication:
all_reduce_fp8(activate_experts, group=self.moe_dp_group)
else:
dist.all_reduce(activate_experts, group=self.moe_dp_group)

input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
output_states, _ = all_to_all_uneven(
dispatch_states,
input_split_list,
output_split_list,
self.ep_group,
fp8_communication=self.fp8_communication,
)
output_states = EPGradScalerIn.apply(output_states, self.ep_size)

if output_states.size(0) > 0:
@@ -167,7 +192,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
output_states_list.append(split_states)
output_states = torch.cat(output_states_list)
output_states = EPGradScalerOut.apply(output_states, self.ep_size)
dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
dispatch_states, _ = all_to_all_uneven(
output_states, output_split_list, input_split_list, self.ep_group, fp8_communication=self.fp8_communication
)
recover_token_idx = torch.empty_like(flat_topk_token_idx)
recover_token_idx[flat_topk_token_idx] = torch.arange(
flat_topk_token_idx.size(0), device=flat_topk_token_idx.device
@@ -534,9 +561,9 @@ def forward(

# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
query_states = all_to_all_comm(query_states, sp_group)
key_states = all_to_all_comm(key_states, sp_group)
value_states = all_to_all_comm(value_states, sp_group)
query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
bsz, q_len, _ = query_states.size()
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
@@ -595,7 +622,9 @@ def forward(
# sp: all-to-all communication when introducing sequence parallel
if sp_mode == "all_to_all":
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() # (1, 8, 128)
attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2) # (1, 4, 256)
attn_output = all_to_all_comm(
attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
) # (1, 4, 256)
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

@@ -685,9 +714,13 @@ def forward(
)

if sp_mode in ["ring", "split_gather"]:
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
inputs_embeds = split_forward_gather_backward(
inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
)
# embed positions
hidden_states = inputs_embeds

@@ -731,9 +764,13 @@ def forward(
hidden_states = self.norm(hidden_states)

if sp_mode == "ring" or sp_mode == "split_gather":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
)
elif sp_mode == "all_to_all":
hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
hidden_states = gather_forward_split_backward(
hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
)

# add hidden states from the last decoder layer
if output_hidden_states:
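The expert-activation counter above is now reduced with `all_reduce_fp8` when the flag is set. A hedged sketch of that call in isolation follows; it assumes an initialized CUDA-capable process group, with the world group standing in for the MoE data-parallel group used in the model code.

```python
# Hedged sketch: in-place FP8 all-reduce of a small per-expert counter tensor.
# Assumes torch.distributed is initialized with a CUDA backend; dist.group.WORLD
# stands in for the moe_dp_group used by EPDeepseekMoE.
import torch
import torch.distributed as dist
from colossalai.quantization.fp8 import all_reduce_fp8

activate_experts = (torch.rand(8, device="cuda") > 0.5).float()  # per-expert activation mask on this rank
all_reduce_fp8(activate_experts, group=dist.group.WORLD)         # summed across ranks, written back in place
```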