6 changes: 3 additions & 3 deletions colossalai/nn/layer/moe/__init__.py
@@ -1,10 +1,10 @@
from .checkpoint import MoeCheckpintIO
-from .experts import EPMLPExperts, TPMLPExperts
+from .experts import EPMLPExperts, TPMLPExperts, build_ffn_experts
from .layers import MoeLayer, MoeModule, SparseMLP
from .routers import MoeRouter, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
+from .utils import NormalNoiseGenerator, UniformNoiseGenerator

__all__ = [
    'EPMLPExperts', 'TPMLPExperts', 'Top1Router', 'Top2Router', 'MoeModule', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'build_ffn_experts', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO'
+    'UniformNoiseGenerator', 'SparseMLP', 'MoeRouter', 'MoeCheckpintIO', 'build_ffn_experts'
]
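
Note: `build_ffn_experts` moves from `utils.py` into `experts.py` (see the hunks below). Since `experts.py` now imports `get_activation` from `utils.py`, keeping the old `utils -> experts` import as well would create a circular import. The public import path is unchanged, e.g.:

```python
# sanity check of the re-export (a sketch, not part of this diff)
from colossalai.nn.layer.moe import EPMLPExperts, TPMLPExperts, build_ffn_experts
```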
51 changes: 38 additions & 13 deletions colossalai/nn/layer/moe/experts.py
@@ -8,6 +8,7 @@
from colossalai.context import ParallelMode, seed
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.nn.layer.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.nn.layer.moe.utils import get_activation
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size, set_moe_tensor_info


@@ -24,11 +25,13 @@ def __init__(
expert_parallel: str = None,
activation: str = None,
drop_rate: float = 0,
+        gated: bool = False,
):
super().__init__()
assert expert_parallel in ["EP", "TP", None]
self.expert_parallel = expert_parallel
self.num_total_experts = num_experts
+        self.gated = gated

# get expert parallel info
if expert_parallel is not None:
@@ -47,14 +50,19 @@
self.num_local_experts = self.num_total_experts
self.ep_size = 1

-        self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
+        if gated:
+            # double-width gate projection; a SwiGLU-style activation halves it back
+            self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
+            self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
+        else:
+            self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))

-        with seed(ParallelMode.TENSOR):
-            nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
-            nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))
+        if expert_parallel is not None:
+            with seed(ParallelMode.TENSOR):
+                # when gated, wi_gate/wi_up exist instead of wi
+                if gated:
+                    nn.init.trunc_normal_(self.wi_gate, std=math.sqrt(0.1 / hidden_size))
+                    nn.init.trunc_normal_(self.wi_up, std=math.sqrt(0.1 / hidden_size))
+                else:
+                    nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
+                nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))

-        self.act = nn.GELU() if activation is None else activation
+        self.act = get_activation(activation)
self.drop = nn.Dropout(p=drop_rate)

if expert_parallel is not None:
@@ -71,10 +79,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:    # inputs [g, e, c, h]
inshape = x.shape
x = x.reshape(e, -1, h)

-        x = torch.bmm(x, self.wi)
-        x = self.act(x)
-        with seed(ParallelMode.TENSOR):
-            x = self.drop(x)
+        if self.gated:
+            # wi_gate projects to 2 * intermediate_size; a SwiGLU-style activation
+            # halves it back before the elementwise product with the up projection
+            x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up)
+        else:
+            x = torch.bmm(x, self.wi)
+            x = self.act(x)
+
+        if self.expert_parallel is not None:
+            with seed(ParallelMode.TENSOR):
+                x = self.drop(x)
+        else:
+            x = self.drop(x)
        x = torch.bmm(x, self.wo)

x = x.reshape(inshape)
@@ -93,8 +106,9 @@ def __init__(self,
hidden_size: int,
intermediate_size: int,
activation=None,
-                 drop_rate: float = 0):
-        super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate)
+                 drop_rate: float = 0,
+                 gated: bool = False):
+        super().__init__(num_experts, hidden_size, intermediate_size, "EP", activation, drop_rate, gated)

def state_dict(self, destination=None, prefix='', keep_vars=False):
dp_rank = dist.get_rank(get_dp_group(self))
@@ -134,8 +148,9 @@ def __init__(self,
hidden_size: int,
intermediate_size: int,
activation: str = None,
-                 drop_rate: float = 0):
-        super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate)
+                 drop_rate: float = 0,
+                 gated: bool = False):
+        super().__init__(num_experts, hidden_size, intermediate_size, "TP", activation, drop_rate, gated)


def get_expert_class(name: str) -> BaseMLPExperts:
@@ -147,3 +162,13 @@ def get_expert_class(name: str) -> BaseMLPExperts:
return BaseMLPExperts
else:
raise ValueError(f"Unknown expert class name: {name}")


def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
    mep_size = MOE_CONTEXT.max_ep_size
    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
        return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
    elif d_ff % mep_size == 0:
        return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
    else:
        raise NotImplementedError(f"Cannot build {num_experts} experts with a maximum expert-parallel size of {mep_size}.")
6 changes: 4 additions & 2 deletions colossalai/nn/layer/moe/layers.py
@@ -60,7 +60,8 @@ def __init__(self,
expert_parallel: str = "EP",
hidden_size: int = 2048,
intermediate_size: int = 2048,
-                 activation: str = None):
+                 activation: str = None,
+                 gated: bool = False):
super().__init__()
self.hidden_size = hidden_size
self.num_experts = num_experts
@@ -82,7 +83,8 @@ def __init__(self,
self.experts: BaseMLPExperts = expert_cls(num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
-                                                 activation=activation)
+                                                 activation=activation,
+                                                 gated=gated)
if expert_parallel is not None:
self.ep_group = get_ep_group(self.experts)
self.ep_size = get_ep_size(self.experts)
35 changes: 23 additions & 12 deletions colossalai/nn/layer/moe/utils.py
@@ -6,8 +6,6 @@
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device

-from .experts import EPMLPExperts, TPMLPExperts


class ForceFP32Parameter(torch.nn.Parameter):

@@ -60,16 +58,6 @@ def autocast_softmax(logit: torch.Tensor, dim: int):
    return F.softmax(logit, dim=dim, dtype=torch.float32)


-def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-    mep_size = MOE_CONTEXT.max_ep_size
-    if num_experts % mep_size == 0 or mep_size % num_experts == 0:
-        return EPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    elif d_ff % mep_size == 0:
-        return TPMLPExperts(num_experts, d_model, d_ff, activation, drop_rate)
-    else:
-        raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")


def get_noise_generator(noise_type: str, num_experts: int) -> Callable:
if noise_type is None:
return None
@@ -80,3 +68,26 @@ def get_noise_generator(noise_type: str, num_experts: int) -> Callable:
else:
raise NotImplementedError("Unsupported input noisy policy")
return noisy_func


def get_activation(act: str) -> Callable:
if act is None or act == 'relu':
return torch.nn.ReLU()
elif act == 'gelu':
return torch.nn.GELU()
elif act == 'swiglu':
return SwiGLU
else:
raise NotImplementedError("Unsupported activation function")


def SwiGLU(x):
    """SwiGLU activation function.

    Splits the last dimension of ``x`` in half and returns ``x1 * silu(x2)``,
    so the output's last dimension is half the input's. The last dimension
    of ``x`` must therefore be even.
    """
    size = x.shape[-1]
    assert size % 2 == 0, "the size of the last dimension must be even"
    x1, x2 = torch.split(x, size // 2, -1)
    return x1 * (x2 * torch.sigmoid(x2))
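
Because `SwiGLU` halves the last dimension, a gated expert allocates `wi_gate` at twice the width of `wi_up` (see `experts.py` above). A quick standalone check of that contract:

```python
import torch
import torch.nn.functional as F

from colossalai.nn.layer.moe.utils import SwiGLU, get_activation

act = get_activation("swiglu")    # returns the SwiGLU function itself
x = torch.randn(4, 16)            # last dimension must be even
y = act(x)
assert y.shape == (4, 8)          # SwiGLU halves the last dimension

# equivalent to splitting manually and applying SiLU to the second half
x1, x2 = torch.split(x, 8, dim=-1)
assert torch.allclose(y, x1 * F.silu(x2))
```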
17 changes: 17 additions & 0 deletions examples/language/openmoe/README.md
@@ -0,0 +1,17 @@
## OpenMoE
[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is a project aimed at igniting the open-source MoE community.

The following [Colossal-AI](https://github.com/hpcaitech/ColossalAI) example demonstrates fine-tuning and inference methods for OpenMoE.


## Our Modifications

We re-implement OpenMoE in PyTorch for GPU execution.

## Run Inference

Run the following script:
```bash
bash infer.sh
```
By default this performs inference with the OpenMoE-base checkpoint; change the `--model` flag in `infer.sh` (choices: `base`, `8b`, `test`) to run [OpenMoE-8B/32E](https://github.com/XueFuzhao/OpenMoE) instead.
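
If you prefer to call the model from Python directly, the core of `infer.py` (added below) boils down to the following sketch (prompt text is illustrative):

```python
# minimal programmatic equivalent of infer.sh, mirroring infer.py below
import torch
from model.modeling_openmoe import OpenMoeForCausalLM
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
model = OpenMoeForCausalLM.from_pretrained("hpcaitech/openmoe-base")
model = model.eval().bfloat16().to(torch.cuda.current_device())

input_ids = tokenizer("<pad>Hello, my name is", return_tensors="pt").input_ids
input_ids = input_ids.to(torch.cuda.current_device())
output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=False))
```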
49 changes: 49 additions & 0 deletions examples/language/openmoe/infer.py
@@ -0,0 +1,49 @@
from argparse import ArgumentParser

import torch
from model.modeling_openmoe import OpenMoeForCausalLM
from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig


def parse_args():
parser = ArgumentParser()
parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b", "test"])
return parser.parse_args()


def inference(args):

tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
if args.model == "test":
config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
model = OpenMoeForCausalLM(config)
else:
model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}")
model = model.eval().bfloat16()
model = model.to(torch.cuda.current_device())

input_str = """```
y = list(map(int, ['1', 'hello', '2']))
```
What error does this program produce?
ValueError: invalid literal for int() with base 10: 'hello'

```
sum = 0
for i in range(100):
sum += i
```
What is the value of sum immediately after the 10th time line 3 is executed?"""

# print("model config: ", model.config)
input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=True)
input_ids = input_ids.input_ids.to(torch.cuda.current_device())
generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=128)
out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
print(f"output: \n{out}\n")


if __name__ == "__main__":
args = parse_args()
inference(args)
1 change: 1 addition & 0 deletions examples/language/openmoe/infer.sh
@@ -0,0 +1 @@
python infer.py --model "base"