From 65b0760c042b32f3e3e31b91d290ec26c2ac5851 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 13 Sep 2023 11:12:52 +0800 Subject: [PATCH 1/9] update triton --- .../kernel/triton/llama_act_combine_kernel.py | 199 ++++++++++++++++++ .../openmoe/model/modeling_openmoe.py | 7 +- tests/test_kernels/test_llama_act_combine.py | 52 +++++ 3 files changed, 255 insertions(+), 3 deletions(-) create mode 100644 colossalai/kernel/triton/llama_act_combine_kernel.py create mode 100644 tests/test_kernels/test_llama_act_combine.py diff --git a/colossalai/kernel/triton/llama_act_combine_kernel.py b/colossalai/kernel/triton/llama_act_combine_kernel.py new file mode 100644 index 000000000000..44627554eb17 --- /dev/null +++ b/colossalai/kernel/triton/llama_act_combine_kernel.py @@ -0,0 +1,199 @@ +from functools import reduce +from typing import Any, Tuple + +import torch +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + PRECISION_MAP = { + "fp32": (0, torch.float32), + "fp16": (1, torch.float16), + "bf16": (2, torch.bfloat16), + } + + @triton.jit + def _llama_act_combine_forward( + X_GATE1, + X_GATE2, + X_UP, + Y, + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + BLOCK_SIZE: tl.constexpr, + ): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X_GATE1 += row * stride + X_GATE2 += row * stride + X_UP += row * stride + Y += row * stride + + # do activation and combine, and store in y + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.).to(tl.float32) + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.).to(tl.float32) + x_up = tl.load(X_UP + cols, mask=mask, other=0.).to(tl.float32) + y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up + + # if PRECISION == 0: + # pass + # elif PRECISION == 1: + # y = y.to(tl.float16) + # elif PRECISION == 2: + # y = y.to(tl.bfloat16) + + # Write output + tl.store(Y + cols, y, mask=mask) + + @triton.jit + def _llama_act_combine_backward( + X_GATE1, + X_GATE2, + X_UP, + X_GATE1_GRAD, + X_GATE2_GRAD, + X_UP_GRAD, + Y_GRAD, + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + BLOCK_SIZE: tl.constexpr, + ): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + X_GATE1 += row * stride + X_GATE2 += row * stride + X_UP += row * stride + X_GATE1_GRAD += row * stride + X_GATE2_GRAD += row * stride + X_UP_GRAD += row * stride + Y_GRAD += row * stride + + # do activation and combine, and store in y + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.).to(tl.float32) + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.).to(tl.float32) + x_up = tl.load(X_UP + cols, mask=mask, other=0.).to(tl.float32) + y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.).to(tl.float32) + + # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up + x_gate2_sigmoid = tl.sigmoid(x_gate2) + x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid + x_up_grad = x_gate2_act * x_gate1 + x_gate1_grad = x_gate2_act * x_up + # grad(x*sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 − sigmoid(x)] + # = sigmoid(x) * {1 + x * [(1 − sigmoid(x)]} + x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid)) + + # Write output + tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask) + tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask) + tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask) + + class LlamaActCombine(torch.autograd.Function): + """ + act(x_gate) * x_up + + Args: + x_gate (torch.Tensor): (b, l, 2d) x gate + x_up (torch.Tensor): (b, l, d) x up + activation (str): only support swiglu + precision (str): fp32, fp16, bf16 + """ + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx: Any, + x_gate: torch.Tensor, + x_up: torch.Tensor, + activation: str = "swiglu", + precision: str = "fp32"): + """ + act(x_gate) * x_up + + Args: + x_gate (torch.Tensor): (b, l, 2d) x gate + x_up (torch.Tensor): (b, l, d) x up + activation (str): only support swiglu + precision (str): fp32, fp16, bf16 + """ + assert activation == "swiglu", "Only swiglu is supported" + assert precision in PRECISION_MAP + precision, dtype = PRECISION_MAP[precision] + + # split x gate + assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2" + x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1) + x_gate1 = x_gate1.contiguous() + x_gate2 = x_gate2.contiguous() + if not x_up.is_contiguous(): + x_up = x_up.contiguous() + # assert shape + assert x_gate1.shape == x_gate2.shape == x_up.shape + + # add ctx for backward + if x_gate.requires_grad: + ctx.save_for_backward(x_gate1, x_gate2, x_up) + + # allocate output + y = torch.empty_like(x_up, dtype=dtype) + M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1] + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x_gate.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + # restore setting + ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps + # enqueue kernel + _llama_act_combine_forward[(M,)](x_gate1, + x_gate2, + x_up, + y, + x_up.stride(-2), + N, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps) + return y + + @staticmethod + @custom_bwd + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + # restore from ctx + (x_gate1, x_gate2, x_up) = ctx.saved_tensors + M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps + + # init grad + y_grad = grad_outputs[0] + x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like( + x_gate2), torch.empty_like(x_up) + + # enqueue kernel + _llama_act_combine_backward[(M,)](x_gate1, + x_gate2, + x_up, + x_gate1_grad, + x_gate2_grad, + x_up_grad, + y_grad, + x_up.stride(-2), + N, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps) + x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1) + return x_gate_grad, x_up_grad, None, None diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index ec7e1e8941f7..84d0d205d31e 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -95,7 +95,7 @@ def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timesc timescale = min_timescale * (max_timescale / min_timescale)**fraction rotational_frequency = 1. / timescale - sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float64).cuda(), rotational_frequency) + sinusoid_inp = torch.einsum('i,j->ij', torch.arange(length, dtype=torch.float32).cuda(), rotational_frequency) sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1) @@ -313,6 +313,7 @@ def __init__(self, config: LlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1e4) self._init_rope() def _init_rope(self): @@ -382,9 +383,9 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) - dim = query_states.shape[-1] max_length = max(query_states.shape[1], key_states.shape[1]) - sin, cos = generate_fixed_pos_embedding(dim, max_length, max_timescale=1e4) + assert max_length <= self.sin.shape[0] + sin, cos = self.sin[:max_length], self.cos[:max_length] query_states, key_states = apply_rotary_embedding(query_states, key_states, cos, diff --git a/tests/test_kernels/test_llama_act_combine.py b/tests/test_kernels/test_llama_act_combine.py new file mode 100644 index 000000000000..0e09d3199ccd --- /dev/null +++ b/tests/test_kernels/test_llama_act_combine.py @@ -0,0 +1,52 @@ +import pytest +import torch +from torch import nn + +from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + +try: + import triton + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +BATCH_SIZE = 4 +SEQ_LEN = 16 +HIDDEN_SIZE = 32 + + +def SwiGLU(x): + """Gated linear unit activation function. + Args: + x : input array + axis: the axis along which the split should be computed (default: -1) + """ + size = x.shape[-1] + assert size % 2 == 0, "axis size must be divisible by 2" + x1, x2 = torch.split(x, size // 2, -1) + return x1 * (x2 * torch.sigmoid(x2)) + + +@pytest.mark.skipif(not HAS_TRITON, reason="requires triton") +def test_llama_act_combine(): + x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2).cuda() + x_gate_torch = nn.Parameter(x_gate.detach().clone()) + x_gate_kernel = nn.Parameter(x_gate.detach().clone()) + x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE).cuda() + x_up_torch = nn.Parameter(x_up.detach().clone()) + x_up_kernel = nn.Parameter(x_up.detach().clone()) + + torch_out = SwiGLU(x_gate_torch) * x_up_torch + kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel) + assert torch.allclose(torch_out, kernel_out, atol=1e-5) + + torch_out.mean().backward() + kernel_out.mean().backward() + assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad]) + assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=1e-5) + assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=1e-5) + + +if __name__ == '__main__': + test_llama_act_combine() From 5783416755e570d3951b0b540b29242e1a8a1b4e Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 13 Sep 2023 11:16:57 +0800 Subject: [PATCH 2/9] update kernel --- colossalai/kernel/triton/llama_act_combine_kernel.py | 4 ++-- examples/language/openmoe/model/modeling_openmoe.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/colossalai/kernel/triton/llama_act_combine_kernel.py b/colossalai/kernel/triton/llama_act_combine_kernel.py index 44627554eb17..60ba44fb631e 100644 --- a/colossalai/kernel/triton/llama_act_combine_kernel.py +++ b/colossalai/kernel/triton/llama_act_combine_kernel.py @@ -107,8 +107,8 @@ class LlamaActCombine(torch.autograd.Function): act(x_gate) * x_up Args: - x_gate (torch.Tensor): (b, l, 2d) x gate - x_up (torch.Tensor): (b, l, d) x up + x_gate (torch.Tensor): (b, l, 2d) x_gate + x_up (torch.Tensor): (b, l, d) x_up activation (str): only support swiglu precision (str): fp32, fp16, bf16 """ diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 84d0d205d31e..d203a67d6caa 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -36,9 +36,13 @@ replace_return_docstrings, ) +from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe.layers import SparseMLP from colossalai.moe.manager import MOE_MANAGER +if HAS_TRITON: + from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "LlamaConfig" @@ -278,7 +282,10 @@ def forward(self, x): down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] down_proj = sum(down_proj) else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + if HAS_TRITON: + down_proj = LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x)) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj From dd37c08af29dc145809c386ac71d6529f0105f86 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 13 Sep 2023 11:32:52 +0800 Subject: [PATCH 3/9] add init --- colossalai/kernel/triton/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 colossalai/kernel/triton/__init__.py diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From 7a3e523eacbc87b52c582999c3606b17e2b9c6ce Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 13 Sep 2023 11:54:42 +0800 Subject: [PATCH 4/9] add version check --- colossalai/kernel/triton/llama_act_combine_kernel.py | 4 ++-- tests/test_kernels/test_llama_act_combine.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/colossalai/kernel/triton/llama_act_combine_kernel.py b/colossalai/kernel/triton/llama_act_combine_kernel.py index 60ba44fb631e..c419d1b61e09 100644 --- a/colossalai/kernel/triton/llama_act_combine_kernel.py +++ b/colossalai/kernel/triton/llama_act_combine_kernel.py @@ -119,7 +119,7 @@ def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu", - precision: str = "fp32"): + precision: str = "fp32") -> torch.Tensor: """ act(x_gate) * x_up @@ -173,7 +173,7 @@ def forward(ctx: Any, @staticmethod @custom_bwd - def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]: + def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]: # restore from ctx (x_gate1, x_gate2, x_up) = ctx.saved_tensors M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps diff --git a/tests/test_kernels/test_llama_act_combine.py b/tests/test_kernels/test_llama_act_combine.py index 0e09d3199ccd..a1d887cc22c4 100644 --- a/tests/test_kernels/test_llama_act_combine.py +++ b/tests/test_kernels/test_llama_act_combine.py @@ -1,5 +1,6 @@ import pytest import torch +from packaging import version from torch import nn from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine @@ -10,6 +11,7 @@ except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') BATCH_SIZE = 4 SEQ_LEN = 16 @@ -28,7 +30,7 @@ def SwiGLU(x): return x1 * (x2 * torch.sigmoid(x2)) -@pytest.mark.skipif(not HAS_TRITON, reason="requires triton") +@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") def test_llama_act_combine(): x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2).cuda() x_gate_torch = nn.Parameter(x_gate.detach().clone()) From def965596d0e8dcd19496f3a0d54dc841f89da86 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 13 Sep 2023 14:07:33 +0800 Subject: [PATCH 5/9] update precision --- .../kernel/triton/llama_act_combine_kernel.py | 40 ++++++------------- .../openmoe/model/modeling_openmoe.py | 2 +- 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/colossalai/kernel/triton/llama_act_combine_kernel.py b/colossalai/kernel/triton/llama_act_combine_kernel.py index c419d1b61e09..45996c0dca53 100644 --- a/colossalai/kernel/triton/llama_act_combine_kernel.py +++ b/colossalai/kernel/triton/llama_act_combine_kernel.py @@ -41,18 +41,11 @@ def _llama_act_combine_forward( for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N - x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.).to(tl.float32) - x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.).to(tl.float32) - x_up = tl.load(X_UP + cols, mask=mask, other=0.).to(tl.float32) - y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up - - # if PRECISION == 0: - # pass - # elif PRECISION == 1: - # y = y.to(tl.float16) - # elif PRECISION == 2: - # y = y.to(tl.bfloat16) - + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) + x_up = tl.load(X_UP + cols, mask=mask, other=0.) + x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) + y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up # Write output tl.store(Y + cols, y, mask=mask) @@ -83,13 +76,13 @@ def _llama_act_combine_backward( for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N - x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.).to(tl.float32) - x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.).to(tl.float32) - x_up = tl.load(X_UP + cols, mask=mask, other=0.).to(tl.float32) - y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.).to(tl.float32) + x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.) + x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.) + x_up = tl.load(X_UP + cols, mask=mask, other=0.) + y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.) # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up - x_gate2_sigmoid = tl.sigmoid(x_gate2) + x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype) x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid x_up_grad = x_gate2_act * x_gate1 x_gate1_grad = x_gate2_act * x_up @@ -114,12 +107,8 @@ class LlamaActCombine(torch.autograd.Function): """ @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward(ctx: Any, - x_gate: torch.Tensor, - x_up: torch.Tensor, - activation: str = "swiglu", - precision: str = "fp32") -> torch.Tensor: + @custom_fwd + def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor: """ act(x_gate) * x_up @@ -127,11 +116,8 @@ def forward(ctx: Any, x_gate (torch.Tensor): (b, l, 2d) x gate x_up (torch.Tensor): (b, l, d) x up activation (str): only support swiglu - precision (str): fp32, fp16, bf16 """ assert activation == "swiglu", "Only swiglu is supported" - assert precision in PRECISION_MAP - precision, dtype = PRECISION_MAP[precision] # split x gate assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2" @@ -148,7 +134,7 @@ def forward(ctx: Any, ctx.save_for_backward(x_gate1, x_gate2, x_up) # allocate output - y = torch.empty_like(x_up, dtype=dtype) + y = torch.empty_like(x_up) M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1] # Less than 64KB per feature: enqueue fused kernel diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index d203a67d6caa..07f572a0d67a 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -283,7 +283,7 @@ def forward(self, x): down_proj = sum(down_proj) else: if HAS_TRITON: - down_proj = LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x)) + down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x))) else: down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) From d21fac9da68a52f3f1ab135f85f29e0ce85748bf Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 13 Sep 2023 14:20:08 +0800 Subject: [PATCH 6/9] update precision --- tests/test_kernels/test_llama_act_combine.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_kernels/test_llama_act_combine.py b/tests/test_kernels/test_llama_act_combine.py index a1d887cc22c4..db7ef4305896 100644 --- a/tests/test_kernels/test_llama_act_combine.py +++ b/tests/test_kernels/test_llama_act_combine.py @@ -27,27 +27,29 @@ def SwiGLU(x): size = x.shape[-1] assert size % 2 == 0, "axis size must be divisible by 2" x1, x2 = torch.split(x, size // 2, -1) - return x1 * (x2 * torch.sigmoid(x2)) + return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype)) @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") -def test_llama_act_combine(): - x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2).cuda() +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_llama_act_combine(dtype: str = torch.float16): + x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda() x_gate_torch = nn.Parameter(x_gate.detach().clone()) x_gate_kernel = nn.Parameter(x_gate.detach().clone()) - x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE).cuda() + x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda() x_up_torch = nn.Parameter(x_up.detach().clone()) x_up_kernel = nn.Parameter(x_up.detach().clone()) torch_out = SwiGLU(x_gate_torch) * x_up_torch kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel) - assert torch.allclose(torch_out, kernel_out, atol=1e-5) + atol = 1e-5 if dtype == torch.float32 else 5e-2 + assert torch.allclose(torch_out, kernel_out, atol=atol) torch_out.mean().backward() kernel_out.mean().backward() assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad]) - assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=1e-5) - assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=1e-5) + assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol) + assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol) if __name__ == '__main__': From a95d72116fefaca7a901d73aee74de023a933a4f Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 13 Sep 2023 15:48:33 +0800 Subject: [PATCH 7/9] update kernel in experts --- colossalai/moe/experts.py | 11 ++++++++++- .../language/openmoe/model/modeling_openmoe.py | 17 +++++++++-------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py index 9715f4dc37b3..4535d8ab9a85 100644 --- a/colossalai/moe/experts.py +++ b/colossalai/moe/experts.py @@ -4,12 +4,17 @@ import torch import torch.nn as nn + +from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_activation from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.moe_tensor.api import get_ep_size, set_moe_tensor_info +if HAS_TRITON: + from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine + class BaseMLPExperts(nn.Module): """ @@ -78,6 +83,7 @@ def __init__( nn.init.trunc_normal_(self.wi, std=math.sqrt(0.1 / hidden_size)) nn.init.trunc_normal_(self.wo, std=math.sqrt(0.1 / intermediate_size)) + self.act_name = activation self.act = get_activation(activation) self.drop = nn.Dropout(p=drop_rate) @@ -103,7 +109,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.reshape(e, -1, h) if self.gated: - x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up) + if HAS_TRITON and self.act_name == "swiglu": + x = LlamaActCombine.apply(torch.bmm(x, self.wi_gate), torch.bmm(x, self.wi_up)) + else: + x = self.act(torch.bmm(x, self.wi_gate)) * torch.bmm(x, self.wi_up) else: x = torch.bmm(x, self.wi) x = self.act(x) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index 07f572a0d67a..cb5f441d6c90 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -125,16 +125,16 @@ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): if decode and qlen == 1 and rotary_index is not None: qcos = cos[rotary_index + 1, :] qsin = sin[rotary_index + 1, :] - qcos = qcos.unsqueeze(2).expand(batch, qlen, qheads, d) - qsin = qsin.unsqueeze(2).expand(batch, qlen, qheads, d) + qcos = qcos.unsqueeze(2) + qsin = qsin.unsqueeze(2) + kcos, ksin = cos[:klen, :], sin[:klen, :] + kcos = kcos.unsqueeze(0).unsqueeze(2) + ksin = ksin.unsqueeze(0).unsqueeze(2) else: qcos, qsin = cos[:qlen, :], sin[:qlen, :] - qcos = qcos.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d) - qsin = qsin.unsqueeze(0).unsqueeze(2).expand(batch, qlen, qheads, d) - - kcos, ksin = cos[:klen, :], sin[:klen, :] - kcos = kcos.unsqueeze(0).unsqueeze(2).expand(batch, klen, kheads, d) - ksin = ksin.unsqueeze(0).unsqueeze(2).expand(batch, klen, kheads, d) + qcos = qcos.unsqueeze(0).unsqueeze(2) + qsin = qsin.unsqueeze(0).unsqueeze(2) + kcos, ksin = qcos, qsin out_q = (q * qcos) + (rotate_half(q) * qsin) out_k = (k * kcos) + (rotate_half(k) * ksin) @@ -393,6 +393,7 @@ def forward( max_length = max(query_states.shape[1], key_states.shape[1]) assert max_length <= self.sin.shape[0] sin, cos = self.sin[:max_length], self.cos[:max_length] + # TODO: for inference, we can add emb kv into cache to avoid computation query_states, key_states = apply_rotary_embedding(query_states, key_states, cos, From 49416b1f2f591d94a78925a88483c08068f50381 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 13 Sep 2023 20:29:35 +0800 Subject: [PATCH 8/9] update test arg --- tests/test_kernels/test_llama_act_combine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_kernels/test_llama_act_combine.py b/tests/test_kernels/test_llama_act_combine.py index db7ef4305896..5341aa35ab90 100644 --- a/tests/test_kernels/test_llama_act_combine.py +++ b/tests/test_kernels/test_llama_act_combine.py @@ -32,7 +32,7 @@ def SwiGLU(x): @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) -def test_llama_act_combine(dtype: str = torch.float16): +def test_llama_act_combine(dtype: str): x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda() x_gate_torch = nn.Parameter(x_gate.detach().clone()) x_gate_kernel = nn.Parameter(x_gate.detach().clone()) @@ -53,4 +53,4 @@ def test_llama_act_combine(dtype: str = torch.float16): if __name__ == '__main__': - test_llama_act_combine() + test_llama_act_combine(torch.float16) From 0f20765652fe93640885e679fe669a8b53d31106 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 14 Sep 2023 09:57:41 +0800 Subject: [PATCH 9/9] update settings --- colossalai/moe/_operation.py | 29 +++-------------------------- colossalai/moe/layers.py | 4 ++-- examples/language/openmoe/infer.py | 6 +++--- 3 files changed, 8 insertions(+), 31 deletions(-) diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index a0753d8581b4..bde457947e3f 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -5,20 +5,11 @@ from torch import Tensor from torch.distributed import ProcessGroup -COL_MOE_KERNEL_FLAG = False - try: from colossalai._C import moe except: - moe = None - - -def build_moe_if_not_prebuilt(): - # load moe kernel during runtime if not pre-built - global moe - if moe is None: - from colossalai.kernel.op_builder import MOEBuilder - moe = MOEBuilder().load() + from colossalai.kernel.op_builder import MOEBuilder + moe = MOEBuilder().load() class AllGather(torch.autograd.Function): @@ -26,12 +17,6 @@ class AllGather(torch.autograd.Function): @staticmethod def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor: - global moe - - if moe is None: - from colossalai.kernel.op_builder import MOEBuilder - moe = MOEBuilder().load() - if ctx is not None: ctx.comm_grp = group @@ -104,9 +89,6 @@ def forward(ctx, tokens, mask, dest_idx, ec): s = tokens.size(0) h = tokens.size(1) - # load moe kernel during runtime if not pre-built - build_moe_if_not_prebuilt() - expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx) ctx.save_for_backward(mask, dest_idx) @@ -134,9 +116,6 @@ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec): c = ec // e h = expert_tokens.size(-1) - # load moe kernel during runtime if not pre-built - build_moe_if_not_prebuilt() - fp16_flag = (expert_tokens.dtype == torch.float16) cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx) @@ -167,9 +146,7 @@ def backward(ctx, tokens_grad): def moe_cumsum(inputs: Tensor): dim0 = inputs.size(0) flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0) - if flag and COL_MOE_KERNEL_FLAG: - # load moe kernel during runtime if not pre-built - build_moe_if_not_prebuilt() + if flag: return moe.cumsum_sub_one(inputs) else: return torch.cumsum(inputs, dim=0) - 1 diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index a3f68cf7a6f1..1255a4816041 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -from colossalai.moe._operation import COL_MOE_KERNEL_FLAG, AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter +from colossalai.moe._operation import AllGather, AllToAll, MoeCombine, MoeDispatch, ReduceScatter from colossalai.moe.experts import BaseMLPExperts, get_expert_class from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.routers import MoeRouter, get_router_cls @@ -58,7 +58,7 @@ def __init__(self, super().__init__() self.hidden_size = hidden_size self.num_experts = num_experts - self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_MANAGER.use_kernel_optim else False + self.use_kernel = True if MOE_MANAGER.use_kernel_optim else False self.expert_parallel = expert_parallel assert expert_parallel in ["EP", "TP", None], f"Unsupported expert parallel type {expert_parallel}" diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py index b41fa2f2e4f1..f59772189827 100644 --- a/examples/language/openmoe/infer.py +++ b/examples/language/openmoe/infer.py @@ -20,7 +20,7 @@ def inference(args): model = OpenMoeForCausalLM(config) else: model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}") - model = model.eval().bfloat16() + model = model.eval().half() model = model.to(torch.cuda.current_device()) input_str = """``` @@ -37,9 +37,9 @@ def inference(args): What is the value of sum immediately after the 10th time line 3 is executed?""" # print("model config: ", model.config) - input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=True) + input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=False) input_ids = input_ids.input_ids.to(torch.cuda.current_device()) - generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=128) + generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=16) out = tokenizer.decode(generation_output[0], skip_special_tokens=False) print(f"output: \n{out}\n")