Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion colossalai/context/moe_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self):
self.max_ep_size = None
self.min_dp_size = None
Comment thread
oahzxl marked this conversation as resolved.
self.aux_loss = None
self.parallel = None
self.use_kernel_optim = True

self.has_setup = False
Expand All @@ -34,13 +35,14 @@ def parallel_info_dict(self):
def is_initialized(self):
return self.has_setup

def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8):
def setup(self, seed: int, use_kernel_optim: bool = True, max_ep_size: int = 8, parallel: bool = None):
assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
assert torch.cuda.is_available(), "MoE requires to enable CUDA first"

self.world_size = dist.get_world_size()
self.max_ep_size = min(max_ep_size, dist.get_world_size())
self.min_dp_size = self.world_size // self.max_ep_size
self.parallel = parallel

# Enabling kernel optimization may raise error in some cases
# Users can close kernel optimization manually
Expand Down Expand Up @@ -103,5 +105,8 @@ def add_loss(self, loss):
def get_loss(self):
    """Return the accumulated MoE auxiliary (load-balancing) loss.

    Returns:
        The value previously stored via ``add_loss`` (``self.aux_loss``);
        may be ``None`` if no loss has been accumulated yet.
    """
    return self.aux_loss

def get_parallel(self):
    """Return the MoE parallel mode recorded at ``setup`` time.

    Returns:
        The ``parallel`` value passed to ``setup`` (``None`` if ``setup``
        was never called or was called without a parallel mode).
    """
    return self.parallel


MOE_CONTEXT = MoeContext()
40 changes: 28 additions & 12 deletions colossalai/nn/layer/moe/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from pathlib import Path
from typing import Optional

Expand All @@ -6,32 +7,47 @@
import torch.nn as nn
from torch.optim import Optimizer

from colossalai.checkpoint_io import CheckpointIO
from colossalai.tensor.moe_tensor.api import get_ep_group
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor


class MoeCheckpintIO(CheckpointIO):
class MoeCheckpintIO(GeneralCheckpointIO):

def __init__(self) -> None:
    """Initialize the MoE checkpoint IO; all behavior comes from the base class."""
    super().__init__()

def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
    """Load a full (unsharded) checkpoint into a possibly expert-parallel MoE model.

    In the checkpoint, each ``.experts.`` weight holds the experts of ALL
    expert-parallel (EP) ranks concatenated along dim 0.  For parameters that
    are MoE tensors, only the contiguous slice of experts owned by this rank's
    EP group is kept before loading.

    Args:
        model: target model; its parameters define the EP layout.
        checkpoint: path to a ``torch.save``-ed state dict.
        strict: forwarded to ``nn.Module.load_state_dict``.
    """
    state_dict = torch.load(checkpoint)
    # Build the name -> parameter map once, not per checkpoint entry.
    model_params = dict(model.named_parameters())
    for name, param in state_dict.items():
        if '.experts.' in name:
            model_param = model_params[name]
            if is_moe_tensor(model_param):
                ep_rank = get_ep_rank(model_param)
                ep_size = get_ep_size(model_param)
                # The global expert count must split evenly across EP ranks;
                # check BEFORE computing the slice so a bad checkpoint fails loudly.
                assert param.shape[0] % ep_size == 0
                expert_num = param.shape[0] // ep_size
                # Keep only this rank's contiguous slice of experts.
                # (Assigning to an existing key during items() iteration is safe:
                # the dict's size does not change.)
                state_dict[name] = param[ep_rank * expert_num:(ep_rank + 1) * expert_num]
    model.load_state_dict(state_dict, strict=strict)

def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
    """Save a full (unsharded) state dict, gathering expert weights across EP ranks.

    For each MoE expert parameter, the expert shards held by the ranks of the
    expert-parallel (EP) group are all-gathered and concatenated along dim 0,
    so the saved checkpoint contains every expert.  Only global rank 0 writes
    the file.

    NOTE(review): ``gather_dtensor`` and ``use_safetensors`` are accepted for
    interface compatibility but ignored here — the file is always written with
    ``torch.save``.
    """
    state_dict = model.state_dict()
    for name, param in model.named_parameters():
        if '.experts.' in name and is_moe_tensor(param):
            ep_group = get_ep_group(param)
            ep_rank = get_ep_rank(param)
            ep_size = get_ep_size(param)
            dp_rank = get_dp_rank(param)
            # Only data-parallel rank 0 participates in the gather; the other
            # DP replicas hold identical copies of the same expert shard.
            if dp_rank == 0:
                # all_gather requires CUDA tensors for the default NCCL backend.
                param = param.data.cuda()
                # Pre-allocate one receive buffer per EP rank (deepcopy keeps
                # shape/dtype/device identical to the local shard).
                all_param = [deepcopy(param) for _ in range(ep_size)]
                # gather param from every ep rank
                dist.all_gather(all_param, param, group=ep_group)
                if ep_rank == 0:
                    # assumes the process-group layout places (ep_rank 0,
                    # dp_rank 0) on global rank 0 — TODO confirm against
                    # MOE_CONTEXT's group construction
                    assert dist.get_rank() == 0
                    all_param = torch.cat(all_param, dim=0)
                    state_dict[name] = all_param.cpu()
    if dist.get_rank() == 0:
        torch.save(state_dict, checkpoint)
    # Keep all ranks in step so nobody races ahead (e.g. deletes/reads the
    # checkpoint) while rank 0 is still writing.
    dist.barrier()
Expand Down
32 changes: 5 additions & 27 deletions colossalai/nn/layer/moe/experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def __init__(

if expert_parallel is not None:
with seed(ParallelMode.TENSOR):
nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
if gated:
nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
else:
nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))

self.act = get_activation(activation)
Expand Down Expand Up @@ -110,32 +114,6 @@ def __init__(self,
gated: bool = False):
super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate, gated)

def state_dict(self, destination=None, prefix='', keep_vars=False):
dp_rank = dist.get_rank(get_dp_group(self))
ep_rank = dist.get_rank(get_ep_group(self))
ep_size = get_ep_size(self)
# dp rank 0 will save the state dict
if dp_rank == 0:
for name, param in self.named_parameters():
if param is self:
continue
# create buffer
buffer_module = deepcopy(param)
# gather param from every ep rank
for source_rank in range(ep_size):
current_prefix = f"{prefix}{source_rank}."
if ep_rank == source_rank:
dist.broadcast(param.data, src=source_rank, group=self.moe_info.ep_group)
else:
dist.broadcast(buffer_module.data, src=source_rank, group=self.moe_info.ep_group)
if ep_rank == 0:
if keep_vars:
destination[current_prefix + name] = buffer_module.cpu()
else:
destination[current_prefix + name] = buffer_module.data.cpu()

dist.barrier()


class TPMLPExperts(BaseMLPExperts):
"""Use tensor parallelism to split each expert evenly, which can deploy experts in
Expand Down
28 changes: 28 additions & 0 deletions colossalai/nn/layer/moe/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
from typing import Callable

import torch
Expand Down Expand Up @@ -91,3 +92,30 @@ def SwiGLU(x):
assert size % 2 == 0, "axis size must be divisible by 2"
x1, x2 = torch.split(x, size // 2, -1)
return x1 * (x2 * torch.sigmoid(x2))


@contextlib.contextmanager
def skip_init():
    """Context manager that temporarily disables torch's in-place random
    initializers, so modules constructed inside the block skip parameter init.

    The original implementation assigned ``fn = _skip_init`` / ``fn = fn_saved``
    inside the loops, which only rebinds the local loop variable and never
    touches ``torch.nn.init`` — the context manager was a no-op.  Patch the
    attributes on the module via ``setattr`` instead, and restore them in
    ``finally`` so an exception inside the block cannot leave the initializers
    permanently disabled.
    """

    def _skip_init(tensor, *args, **kwargs):
        # The real in-place init fns return their tensor; keep that contract.
        return tensor

    # Same set of initializers the original targeted.
    init_fn_names = [
        'constant_', 'uniform_', 'normal_', 'xavier_uniform_',
        'xavier_normal_', 'kaiming_uniform_', 'kaiming_normal_'
    ]
    # __enter__: save the real functions, then swap in the no-op.
    saved_fns = {name: getattr(torch.nn.init, name) for name in init_fn_names}
    try:
        for name in init_fn_names:
            setattr(torch.nn.init, name, _skip_init)
        yield
    finally:
        # __exit__: always restore the real initializers.
        for name, real_fn in saved_fns.items():
            setattr(torch.nn.init, name, real_fn)
99 changes: 72 additions & 27 deletions examples/language/openmoe/model/modeling_openmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
replace_return_docstrings,
)

from colossalai.context import MOE_CONTEXT
from colossalai.nn.layer.moe.layers import SparseMLP

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -99,11 +100,14 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc

sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)

return torch.sin(sinusoid_inp).to(torch.bfloat16), torch.cos(sinusoid_inp).to(torch.bfloat16)
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)


def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None):
"""Helper function to apply Rotary Embeddings."""
cos = cos.to(q.dtype)
sin = sin.to(q.dtype)

if len(k.shape) == 3:
# for multi query attention
k = k.unsqueeze(2)
Expand Down Expand Up @@ -405,6 +409,8 @@ def forward(
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
if self.training:
attention_mask = attention_mask.clone().detach()
attention_mask[:, :, :, 0] = 0
attn_weights = attn_weights + attention_mask

Expand Down Expand Up @@ -442,18 +448,19 @@ def __init__(self, config: LlamaConfig, moe: bool):
self.input_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
if self.moe:
self.mlp = SparseMLP(num_experts=config.num_experts,
top_k=config.topk,
capacity_factor_train=config.capacity_factor_train,
capacity_factor_eval=config.capacity_factor_eval,
min_capacity=config.min_capacity,
noisy_policy=config.noisy_policy,
drop_tks=config.drop_tks,
expert_parallel=config.expert_parallel,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
activation=config.hidden_act,
gated=config.gated)
self.mlp = SparseMLP(
num_experts=config.num_experts,
top_k=config.topk,
capacity_factor_train=config.capacity_factor_train,
capacity_factor_eval=config.capacity_factor_eval,
min_capacity=config.min_capacity,
noisy_policy=config.noisy_policy,
drop_tks=config.drop_tks,
expert_parallel=MOE_CONTEXT.get_parallel() if MOE_CONTEXT.is_initialized else config.expert_parallel,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
activation=config.hidden_act,
gated=config.gated)
self.pre_extra_mlp_layernorm = T5LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.extra_mlp = LlamaMLP(config)
else:
Expand Down Expand Up @@ -860,6 +867,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
chunk_head: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
Expand Down Expand Up @@ -910,22 +918,59 @@ def forward(
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
# if no training, just do forward
if labels is None:
logits = self.lm_head(hidden_states)
logits = logits.float()
# the vocab size for openmoe is 30w+
# which causes great activation memory in training, up to 20G for one sequence
# so we use chunk and checkpoint to reduce memory
else:
if chunk_head == True:

def create_custom_forward(module):

def custom_forward(*inputs):
logits = module(inputs[0])
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous().float()
shift_labels = inputs[1][..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return loss

return custom_forward

loss = 0.
for batch_idx in range(hidden_states.shape[0]):
loss = loss + torch.utils.checkpoint.checkpoint(
create_custom_forward(self.lm_head),
hidden_states[batch_idx, :],
labels[batch_idx, :],
)
loss = loss / hidden_states.shape[0]
logits = None
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

if not return_dict:
output = (logits,) + outputs[1:]
Expand Down
Loading