diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/colossalai/kernel/triton/llama_act_combine_kernel.py b/colossalai/kernel/triton/llama_act_combine_kernel.py
new file mode 100644
index 000000000000..45996c0dca53
--- /dev/null
+++ b/colossalai/kernel/triton/llama_act_combine_kernel.py
@@ -0,0 +1,185 @@
+from functools import reduce
+from typing import Any, Tuple
+
+import torch
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+    import triton
+    import triton.language as tl
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+    PRECISION_MAP = {
+        "fp32": (0, torch.float32),
+        "fp16": (1, torch.float16),
+        "bf16": (2, torch.bfloat16),
+    }
+
+    @triton.jit
+    def _llama_act_combine_forward(
+        X_GATE1,
+        X_GATE2,
+        X_UP,
+        Y,
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        # Map the program id to the row of X and Y it should compute.
+        row = tl.program_id(0)
+        X_GATE1 += row * stride
+        X_GATE2 += row * stride
+        X_UP += row * stride
+        Y += row * stride
+
+        # do activation and combine, and store in y
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            mask = cols < N
+            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
+            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
+            x_up = tl.load(X_UP + cols, mask=mask, other=0.)
+            x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
+            y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
+            # Write output
+            tl.store(Y + cols, y, mask=mask)
+
+    @triton.jit
+    def _llama_act_combine_backward(
+        X_GATE1,
+        X_GATE2,
+        X_UP,
+        X_GATE1_GRAD,
+        X_GATE2_GRAD,
+        X_UP_GRAD,
+        Y_GRAD,
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        # Map the program id to the row of X and Y it should compute.
+        row = tl.program_id(0)
+        X_GATE1 += row * stride
+        X_GATE2 += row * stride
+        X_UP += row * stride
+        X_GATE1_GRAD += row * stride
+        X_GATE2_GRAD += row * stride
+        X_UP_GRAD += row * stride
+        Y_GRAD += row * stride
+
+        # do activation and combine, and store in y
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            mask = cols < N
+            x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
+            x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
+            x_up = tl.load(X_UP + cols, mask=mask, other=0.)
+            y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
+
+            # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
+            x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
+            x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid
+            x_up_grad = x_gate2_act * x_gate1
+            x_gate1_grad = x_gate2_act * x_up
+            # grad(x*sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 - sigmoid(x)]
+            #                    = sigmoid(x) * [1 + x * (1 - sigmoid(x))]
+            x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid))
+
+            # Write output
+            tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask)
+            tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask)
+            tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask)
+
+    class LlamaActCombine(torch.autograd.Function):
+        """
+        act(x_gate) * x_up
+
+        Args:
+            x_gate (torch.Tensor): (b, l, 2d) x_gate
+            x_up (torch.Tensor): (b, l, d) x_up
+            activation (str): only "swiglu" is supported
+        """
+
+        @staticmethod
+        @custom_fwd
+        def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor:
+            """
+            act(x_gate) * x_up
+
+            Args:
+                x_gate (torch.Tensor): (b, l, 2d) x gate
+                x_up (torch.Tensor): (b, l, d) x up
+                activation (str): only "swiglu" is supported
+            """
+            assert activation == "swiglu", "Only swiglu is supported"
+
+            # split x gate
+            assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2"
+            x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1)
+            x_gate1 = x_gate1.contiguous()
+            x_gate2 = x_gate2.contiguous()
+            if not x_up.is_contiguous():
+                x_up = x_up.contiguous()
+            # assert shape
+            assert x_gate1.shape == x_gate2.shape == x_up.shape
+
+            # save tensors for backward
+            if x_gate.requires_grad:
+                ctx.save_for_backward(x_gate1, x_gate2, x_up)
+
+            # allocate output
+            y = torch.empty_like(x_up)
+            M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1]
+
+            # Less than 64KB per feature: enqueue fused kernel
+            MAX_FUSED_SIZE = 65536 // x_gate.element_size()
+            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+            if N > BLOCK_SIZE:
+                raise RuntimeError("This kernel doesn't support feature dim >= 64KB.")
+            # heuristics for number of warps
+            num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+            # save launch config for backward
+            ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
+            # enqueue kernel
+            _llama_act_combine_forward[(M,)](x_gate1,
+                                             x_gate2,
+                                             x_up,
+                                             y,
+                                             x_up.stride(-2),
+                                             N,
+                                             BLOCK_SIZE=BLOCK_SIZE,
+                                             num_warps=num_warps)
+            return y
+
+        @staticmethod
+        @custom_bwd
+        def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]:
+            # restore from ctx
+            (x_gate1, x_gate2, x_up) = ctx.saved_tensors
+            M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps
+
+            # init grad
+            y_grad = grad_outputs[0]
+            x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
+                x_gate2), torch.empty_like(x_up)
+
+            # enqueue kernel
+            _llama_act_combine_backward[(M,)](x_gate1,
+                                              x_gate2,
+                                              x_up,
+                                              x_gate1_grad,
+                                              x_gate2_grad,
+                                              x_up_grad,
+                                              y_grad,
+                                              x_up.stride(-2),
+                                              N,
+                                              BLOCK_SIZE=BLOCK_SIZE,
+                                              num_warps=num_warps)
+            x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
+            return x_gate_grad, x_up_grad, None, None
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index a0753d8581b4..bde457947e3f 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -5,20 +5,11 @@
 from torch import Tensor
 from torch.distributed import ProcessGroup
 
-COL_MOE_KERNEL_FLAG = False
-
 try:
     from colossalai._C import moe
 except:
-    moe = None
-
-
-def build_moe_if_not_prebuilt():
-    # load moe kernel during runtime if not pre-built
-    global moe
-    if moe is None:
-        from colossalai.kernel.op_builder import MOEBuilder
-        moe = MOEBuilder().load()
+    from colossalai.kernel.op_builder import MOEBuilder
+    moe = MOEBuilder().load()
 
 
 class AllGather(torch.autograd.Function):
@@ -26,12 +17,6 @@ class AllGather(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
-        global moe
-
-        if moe is None:
-            from colossalai.kernel.op_builder import MOEBuilder
-            moe = MOEBuilder().load()
-
         if ctx is not None:
             ctx.comm_grp = group
 
@@ -104,9 +89,6 @@ def forward(ctx, tokens, mask, dest_idx, ec):
         s = tokens.size(0)
         h = tokens.size(1)
 
-        # load moe kernel during runtime if not pre-built
-        build_moe_if_not_prebuilt()
-
         expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
 
         ctx.save_for_backward(mask, dest_idx)
@@ -134,9 +116,6 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
         c = ec // e
         h = expert_tokens.size(-1)
 
-        # load moe kernel during runtime if not pre-built
-        build_moe_if_not_prebuilt()
-
         fp16_flag = (expert_tokens.dtype == torch.float16)
         cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
         ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
@@ -167,9 +146,7 @@ def backward(ctx, tokens_grad):
 def moe_cumsum(inputs: Tensor):
     dim0 = inputs.size(0)
     flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
-    if flag and COL_MOE_KERNEL_FLAG:
-        # load moe kernel during runtime if not pre-built
-        build_moe_if_not_prebuilt()
+    if flag:
         return moe.cumsum_sub_one(inputs)
     else:
         return torch.cumsum(inputs, dim=0) - 1
diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py
index 9715f4dc37b3..4535d8ab9a85 100644
--- a/colossalai/moe/experts.py
+++ b/colossalai/moe/experts.py
@@ -4,12 +4,17 @@
 
 import torch
 import torch.nn as nn
+
+from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
 from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.utils import get_activation
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info
 
+if HAS_TRITON:
+    from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
+
 
 class BaseMLPExperts(nn.Module):
     """
@@ -78,6 +83,7 @@ def __init__(
             nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size))
             nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size))
 
+        self.act_name = activation
         self.act = get_activation(activation)
         self.drop = nn.Dropout(p=drop_rate)
 
@@ -103,7 +109,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.reshape(e, -1, h)
 
         if self.gated:
-            x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up)
+            if HAS_TRITON and self.act_name == "swiglu":
+                x = LlamaActCombine.apply(torch.bmm(x, self.wi_gate), torch.bmm(x, self.wi_up))
+            else:
+                x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up)
         else:
             x = torch.bmm(x, self.wi)
             x = self.act(x)
diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py
index a3f68cf7a6f1..1255a4816041 100644
--- a/colossalai/moe/layers.py
+++ b/colossalai/moe/layers.py
@@ -5,7 +5,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from colossalai.moe._operation import COL_MOE_KERNEL_FLAG, AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
+from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter
 from colossalai.moe.experts import BaseMLPExperts, get_expert_class
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
@@ -58,7 +58,7 @@ def __init__(self,
         super().__init__()
         self.hidden_size = hidden_size
         self.num_experts = num_experts
-        self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_MANAGER.use_kernel_optim else False
+        self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False
         self.expert_parallel = expert_parallel
         assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}"
 
diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py
index b41fa2f2e4f1..f59772189827 100644
--- a/examples/language/openmoe/infer.py
+++ b/examples/language/openmoe/infer.py
@@ -20,7 +20,7 @@ def inference(args):
         model = OpenMoeForCausalLM(config)
     else:
         model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}")
-    model = model.eval().bfloat16()
+    model = model.eval().half()
     model = model.to(torch.cuda.current_device())
 
     input_str = """```
@@ -37,9 +37,9 @@
What is the value of sum immediately after the 10th time line 3 is executed?"""
 
     # print("model config: ", model.config)
-    input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=True)
+    input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=False)
     input_ids = input_ids.input_ids.to(torch.cuda.current_device())
-    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=128)
+    generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=16)
     out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
     print(f"output: \n{out}\n")
 
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
index ec7e1e8941f7..cb5f441d6c90 100644
--- a/examples/language/openmoe/model/modeling_openmoe.py
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -36,9 +36,13 @@
     replace_return_docstrings,
 )
 
+from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
 from colossalai.moe.layers import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
 
+if HAS_TRITON:
+    from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
@@ -95,7 +99,7 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc
     timescale = min_timescale * (max_timescale / min_timescale)**fraction
     rotational_frequency = 1. / timescale
-    sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float64).cuda(), rotational_frequency)
+    sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float32).cuda(), rotational_frequency)
 
     sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)
 
@@ -121,16 +125,16 @@ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None):
     if decode and qlen == 1 and rotary_index is not None:
         qcos = cos[rotary_index + 1, :]
         qsin = sin[rotary_index + 1, :]
-        qcos = qcos.unsqueeze(2).expand(batch, qlen, qheads, d)
-        qsin = qsin.unsqueeze(2).expand(batch, qlen, qheads, d)
+        qcos = qcos.unsqueeze(2)
+        qsin = qsin.unsqueeze(2)
+        kcos, ksin = cos[:klen, :], sin[:klen, :]
+        kcos = kcos.unsqueeze(0).unsqueeze(2)
+        ksin = ksin.unsqueeze(0).unsqueeze(2)
     else:
         qcos, qsin = cos[:qlen, :], sin[:qlen, :]
-        qcos = qcos.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d)
-        qsin = qsin.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d)
-
-        kcos, ksin = cos[:klen, :], sin[:klen, :]
-        kcos = kcos.unsqueeze(0).unsqueeze(2).expand(batch, klen, kheads, d)
-        ksin = ksin.unsqueeze(0).unsqueeze(2).expand(batch, klen, kheads, d)
+        qcos = qcos.unsqueeze(0).unsqueeze(2)
+        qsin = qsin.unsqueeze(0).unsqueeze(2)
+        kcos, ksin = qcos, qsin
 
     out_q = (q * qcos) + (rotate_half(q) * qsin)
     out_k = (k * kcos) + (rotate_half(k) * ksin)
@@ -278,7 +282,10 @@ def forward(self, x):
             down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
             down_proj = sum(down_proj)
         else:
-            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+            if HAS_TRITON:
+                down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x)))
+            else:
+                down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
         return down_proj
 
@@ -313,6 +320,7 @@ def __init__(self, config: LlamaConfig):
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, max_timescale=1e4)
         self._init_rope()
 
     def _init_rope(self):
@@ -382,9 +390,10 @@ def forward(
         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
 
-        dim = query_states.shape[-1]
         max_length = max(query_states.shape[1], key_states.shape[1])
-        sin, cos = generate_fixed_pos_embedding(dim, max_length, max_timescale=1e4)
+        assert max_length <= self.sin.shape[0]
+        sin, cos = self.sin[:max_length], self.cos[:max_length]
+        # TODO: for inference, we can add emb kv into cache to avoid computation
         query_states, key_states = apply_rotary_embedding(query_states,
                                                           key_states,
                                                           cos,
diff --git a/tests/test_kernels/test_llama_act_combine.py b/tests/test_kernels/test_llama_act_combine.py
new file mode 100644
index 000000000000..5341aa35ab90
--- /dev/null
+++ b/tests/test_kernels/test_llama_act_combine.py
@@ -0,0 +1,56 @@
+import pytest
+import torch
+from packaging import version
+from torch import nn
+
+from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
+
+try:
+    import triton
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4')
+
+BATCH_SIZE = 4
+SEQ_LEN = 16
+HIDDEN_SIZE = 32
+
+
+def SwiGLU(x):
+    """Gated linear unit (SwiGLU) reference activation.
+    Args:
+        x: input tensor of shape (..., 2d); the last dim is split into
+           x1, x2 and the result is x1 * silu(x2)
+    """
+    size = x.shape[-1]
+    assert size % 2 == 0, "axis size must be divisible by 2"
+    x1, x2 = torch.split(x, size // 2, -1)
+    return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype))
+
+
+@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_llama_act_combine(dtype: torch.dtype):
+    x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda()
+    x_gate_torch = nn.Parameter(x_gate.detach().clone())
+    x_gate_kernel = nn.Parameter(x_gate.detach().clone())
+    x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda()
+    x_up_torch = nn.Parameter(x_up.detach().clone())
+    x_up_kernel = nn.Parameter(x_up.detach().clone())
+
+    torch_out = SwiGLU(x_gate_torch) * x_up_torch
+    kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel)
+    atol = 1e-5 if dtype == torch.float32 else 5e-2
+    assert torch.allclose(torch_out, kernel_out, atol=atol)
+
+    torch_out.mean().backward()
+    kernel_out.mean().backward()
+    assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad])
+    assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol)
+    assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol)
+
+
+if __name__ == '__main__':
+    test_llama_act_combine(torch.float16)
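Note: for reference, the fused op added in llama_act_combine_kernel.py computes y = x_gate1 * silu(x_gate2) * x_up, where x_gate is split in half along its last dimension and silu(x) = x * sigmoid(x). The snippet below is a minimal pure-PyTorch restatement of that math; it is not part of the patch, the name llama_act_combine_ref is only illustrative, and it can serve as a quick sanity check of the Triton kernels on small inputs.

import torch


def llama_act_combine_ref(x_gate: torch.Tensor, x_up: torch.Tensor) -> torch.Tensor:
    """Unfused reference: split x_gate, apply SwiGLU, and combine with x_up."""
    x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, dim=-1)
    # the sigmoid is taken in fp32 for numerical stability, as in the Triton kernel
    x_gate2_silu = x_gate2 * torch.sigmoid(x_gate2.float()).to(x_gate2.dtype)
    return x_gate1 * x_gate2_silu * x_up


if __name__ == "__main__":
    x_gate = torch.randn(2, 8, 64, requires_grad=True)
    x_up = torch.randn(2, 8, 32, requires_grad=True)
    y = llama_act_combine_ref(x_gate, x_up)
    # autograd reproduces the gradients written by _llama_act_combine_backward:
    #   dx_up    = dy * x_gate1 * silu(x_gate2)
    #   dx_gate1 = dy * x_up * silu(x_gate2)
    #   dx_gate2 = dy * x_gate1 * x_up * sigmoid(x_gate2) * (1 + x_gate2 * (1 - sigmoid(x_gate2)))
    y.sum().backward()
    print(y.shape, x_gate.grad.shape, x_up.grad.shape)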