185 changes: 185 additions & 0 deletions colossalai/kernel/triton/llama_act_combine_kernel.py
@@ -0,0 +1,185 @@
from functools import reduce
from typing import Any, Tuple

import torch
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")

if HAS_TRITON:
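# precision tag -> (enum id, torch dtype); note: not referenced elsewhere in this file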
PRECISION_MAP = {
"fp32": (0, torch.float32),
"fp16": (1, torch.float16),
"bf16": (2, torch.bfloat16),
}

@triton.jit
def _llama_act_combine_forward(
X_GATE1,
X_GATE2,
X_UP,
Y,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X_GATE1 += row * stride
X_GATE2 += row * stride
X_UP += row * stride
Y += row * stride

# do activation and combine, and store in y
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
# Write output
tl.store(Y + cols, y, mask=mask)

@triton.jit
def _llama_act_combine_backward(
X_GATE1,
X_GATE2,
X_UP,
X_GATE1_GRAD,
X_GATE2_GRAD,
X_UP_GRAD,
Y_GRAD,
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X_GATE1 += row * stride
X_GATE2 += row * stride
X_UP += row * stride
X_GATE1_GRAD += row * stride
X_GATE2_GRAD += row * stride
X_UP_GRAD += row * stride
Y_GRAD += row * stride

# compute input gradients and store them
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
x_up = tl.load(X_UP + cols, mask=mask, other=0.)
y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)

# forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid
x_up_grad = x_gate2_act * x_gate1
x_gate1_grad = x_gate2_act * x_up
# grad(x * sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 - sigmoid(x)]
#                      = sigmoid(x) * {1 + x * [1 - sigmoid(x)]}
x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid))

# Write output
tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask)
tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask)
tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask)

class LlamaActCombine(torch.autograd.Function):
"""
act(x_gate) * x_up

Args:
x_gate (torch.Tensor): (b, l, 2d) gate projection output, split in half along the last dim
x_up (torch.Tensor): (b, l, d) up projection output
activation (str): only "swiglu" is supported
"""

@staticmethod
@custom_fwd
def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor:
"""
act(x_gate) * x_up

Args:
x_gate (torch.Tensor): (b, l, 2d) gate projection output
x_up (torch.Tensor): (b, l, d) up projection output
activation (str): only "swiglu" is supported
"""
assert activation == "swiglu", "Only swiglu is supported"

# split x_gate into two halves along the last dim
assert x_gate.shape[-1] % 2 == 0, "the last dim of x_gate must be divisible by 2"
x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1)
x_gate1 = x_gate1.contiguous()
x_gate2 = x_gate2.contiguous()
if not x_up.is_contiguous():
x_up = x_up.contiguous()
# assert shape
assert x_gate1.shape == x_gate2.shape == x_up.shape

# save tensors for backward
if x_gate.requires_grad or x_up.requires_grad:
ctx.save_for_backward(x_gate1, x_gate2, x_up)

# allocate output
y = torch.empty_like(x_up)
M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1]

# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x_gate.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This kernel doesn't support feature dims larger than 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
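# e.g. for fp16 inputs (element_size == 2): MAX_FUSED_SIZE = 32768; with N = 11008
# (LLaMA-7B's FFN width), BLOCK_SIZE = next_power_of_2(11008) = 16384 and num_warps = 8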
# save launch config for backward
ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
# enqueue kernel
_llama_act_combine_forward[(M,)](x_gate1,
x_gate2,
x_up,
y,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
return y

@staticmethod
@custom_bwd
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None]:
# restore from ctx
(x_gate1, x_gate2, x_up) = ctx.saved_tensors
M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps

# init grad
y_grad = grad_outputs[0]
x_gate1_grad = torch.empty_like(x_gate1)
x_gate2_grad = torch.empty_like(x_gate2)
x_up_grad = torch.empty_like(x_up)

# enqueue kernel
_llama_act_combine_backward[(M,)](x_gate1,
x_gate2,
x_up,
x_gate1_grad,
x_gate2_grad,
x_up_grad,
y_grad,
x_up.stride(-2),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps)
x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
return x_gate_grad, x_up_grad, None
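A minimal sanity-check sketch for the fused kernel above (an illustration, not part of the diff). It assumes a CUDA device with triton installed; shapes and tolerances are arbitrary. The eager reference mirrors the kernel's forward math, x_gate1 * x_gate2 * sigmoid(x_gate2) * x_up:

```python
import torch

from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine

# toy inputs: batch 2, seq 4, intermediate dim 16 (x_gate carries both halves, so 32)
x_gate = torch.randn(2, 4, 32, device="cuda", dtype=torch.float16, requires_grad=True)
x_up = torch.randn(2, 4, 16, device="cuda", dtype=torch.float16, requires_grad=True)

# eager reference: split the gate and apply x1 * x2 * sigmoid(x2) * up
g1, g2 = x_gate.chunk(2, dim=-1)
ref = g1 * g2 * torch.sigmoid(g2.float()).to(g2.dtype) * x_up

fused = LlamaActCombine.apply(x_gate, x_up)
torch.testing.assert_close(fused, ref, rtol=1e-2, atol=1e-2)

# the backward pass also goes through the Triton kernel
fused.float().sum().backward()
```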
29 changes: 3 additions & 26 deletions colossalai/moe/_operation.py
@@ -5,33 +5,18 @@
from torch import Tensor
from torch.distributed import ProcessGroup

COL_MOE_KERNEL_FLAG = False

try:
from colossalai._C import moe
except:
moe = None


def build_moe_if_not_prebuilt():
# load moe kernel during runtime if not pre-built
global moe
if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()


class AllGather(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:

global moe

if moe is None:
from colossalai.kernel.op_builder import MOEBuilder
moe = MOEBuilder().load()

if ctx is not None:
ctx.comm_grp = group

@@ -104,9 +89,6 @@ def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)

# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()

expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)

ctx.save_for_backward(mask, dest_idx)
@@ -134,9 +116,6 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
c = ec // e
h = expert_tokens.size(-1)

# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()

fp16_flag = (expert_tokens.dtype == torch.float16)
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
@@ -167,9 +146,7 @@ def backward(ctx, tokens_grad):
def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and COL_MOE_KERNEL_FLAG:
# load moe kernel during runtime if not pre-built
build_moe_if_not_prebuilt()
if flag:
return moe.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
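For reference, a small sketch of what `moe_cumsum` computes in its eager branch (the `moe.cumsum_sub_one` kernel branch is assumed to return the same values, since the function falls back between the two): a cumulative sum along the token dimension, minus one, turns a one-hot routing mask into each token's 0-based slot index within its expert.

```python
import torch

# rows = tokens, cols = experts; 1 marks the expert a token is routed to
mask = torch.tensor([[1, 0],
                     [1, 0],
                     [0, 1],
                     [1, 0]])

slots = torch.cumsum(mask, dim=0) - 1   # eager branch of moe_cumsum
print(slots)
# tensor([[ 0, -1],
#         [ 1, -1],
#         [ 1,  0],
#         [ 2,  0]])
# entries where mask == 0 are presumably ignored by the downstream dispatch
```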
11 changes: 10 additions & 1 deletion colossalai/moe/experts.py
@@ -4,12 +4,17 @@

import torch
import torch.nn as nn

from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import get_activation
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info

if HAS_TRITON:
from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine


class BaseMLPExperts(nn.Module):
"""
@@ -78,6 +83,7 @@ def __init__(
nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))

self.act_name = activation
self.act = get_activation(activation)
self.drop = nn.Dropout(p=drop_rate)

@@ -103,7 +109,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.reshape(e, -1, h)

if self.gated:
x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up)
if HAS_TRITON and self.act_name == "swiglu":
x = LlamaActCombine.apply(torch.bmm(x, self.wi_gate), torch.bmm(x, self.wi_up))
else:
x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up)
else:
x = torch.bmm(x, self.wi)
x = self.act(x)
4 changes: 2 additions & 2 deletions colossalai/moe/layers.py
@@ -5,7 +5,7 @@
import torch.nn as nn
import torch.nn.functional as F

from colossalai.moe._operation import COL_MOE_KERNEL_FLAG, AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe.experts import BaseMLPExperts, get_expert_class
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.routers import MoeRouter, get_router_cls
@@ -58,7 +58,7 @@ def __init__(self,
super().__init__()
self.hidden_size = hidden_size
self.num_experts = num_experts
self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_MANAGER.use_kernel_optim else False
self.use_kernel = bool(MOE_MANAGER.use_kernel_optim)
self.expert_parallel = expert_parallel
assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}"

6 changes: 3 additions & 3 deletions examples/language/openmoe/infer.py
@@ -20,7 +20,7 @@ def inference(args):
model = OpenMoeForCausalLM(config)
else:
model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}")
model = model.eval().bfloat16()
model = model.eval().half()
model = model.to(torch.cuda.current_device())

input_str = """```
@@ -37,9 +37,9 @@
What is the value of sum immediately after the 10th time line 3 is executed?"""

# print("model config: ", model.config)
input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=True)
input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=False)
input_ids = input_ids.input_ids.to(torch.cuda.current_device())
generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=128)
generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=16)
out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
print(f"output: \n{out}\n")
