diff --git a/LICENSE b/LICENSE
index c7a5bb16880e..3f2167477589 100644
--- a/LICENSE
+++ b/LICENSE
@@ -396,3 +396,21 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 POSSIBILITY OF SUCH DAMAGE.
+
+   ---------------- LICENSE FOR VLLM TEAM ----------------
+
+   from VLLM TEAM:
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       https://github.com/vllm-project/vllm/blob/main/LICENSE
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+
+
diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp
new file mode 100644
index 000000000000..30d236dee895
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Colossal-AI Team
+   Copyright (c) 2023, The vLLM team.
+   This file is adapted from the vLLM team: https://github.com/vllm-project/vllm/blob/main/csrc/layernorm.cpp
+*/
+#include <torch/extension.h>
+
+void rms_norm(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& weight,
+  float epsilon);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "rms_norm",
+    &rms_norm,
+    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+}
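The binding above exposes a single op, rms_norm(out, input, weight, epsilon), which writes its result into a preallocated out tensor. As a rough illustration only (the module name below is made up and the supported path is the op builder added later in this PR), the two source files could also be JIT-built and called directly:

import torch
from torch.utils.cpp_extension import load

# Hypothetical JIT build of the sources added in this diff; illustrative sketch only.
rmsnorm_ext = load(
    name="colossalai_rmsnorm_sketch",
    sources=[
        "colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm.cpp",
        "colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu",
    ],
    verbose=True,
)

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")   # [num_tokens, hidden_size]
weight = torch.ones(4096, dtype=torch.float16, device="cuda")   # [hidden_size]
out = torch.empty_like(x)
rmsnorm_ext.rms_norm(out, x, weight, 1e-6)                      # result is written into `out`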
diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu
new file mode 100644
index 000000000000..1f6bff6a13d1
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/layernorm_kernels.cu
@@ -0,0 +1,63 @@
+/* Copyright 2021 The Colossal-AI Team
+   Copyright (c) 2023, The vLLM team.
+   This file is adapted from the vLLM team: https://github.com/vllm-project/vllm/blob/main/csrc/layernorm_kernels.cu
+*/
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "reduction_utils.cuh"
+
+template<typename scalar_t>
+__global__ void rms_norm_kernel(
+  scalar_t* __restrict__ out,             // [num_tokens, hidden_size]
+  const scalar_t* __restrict__ input,     // [num_tokens, hidden_size]
+  const scalar_t* __restrict__ weight,    // [hidden_size]
+  const float epsilon,
+  const int num_tokens,
+  const int hidden_size) {
+  __shared__ float s_variance;
+  float variance = 0.0f;
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    const float x = (float) input[blockIdx.x * hidden_size + idx];
+    variance += x * x;
+  }
+  variance = blockReduceSum<float>(variance);
+  if (threadIdx.x == 0) {
+    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  }
+  __syncthreads();
+
+  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+    float x = (float) input[blockIdx.x * hidden_size + idx];
+    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+  }
+}
+
+
+void rms_norm(
+  torch::Tensor& out,      // [num_tokens, hidden_size]
+  torch::Tensor& input,    // [num_tokens, hidden_size]
+  torch::Tensor& weight,   // [hidden_size]
+  float epsilon) {
+  int num_tokens = input.size(0);
+  int hidden_size = input.size(1);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
+    input.scalar_type(),
+    "rms_norm_kernel",
+    [&] {
+      rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<scalar_t>(),
+        input.data_ptr<scalar_t>(),
+        weight.data_ptr<scalar_t>(),
+        epsilon,
+        num_tokens,
+        hidden_size);
+    });
+}
\ No newline at end of file
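For reference, the kernel launches one block per token, accumulates the sum of squares in fp32, and rescales each element by rsqrt(mean(x^2) + eps) times the learned weight, which is the same math as HuggingFace's LlamaRMSNorm. A minimal PyTorch sketch of that computation (the helper name is not part of this PR):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # x: [num_tokens, hidden_size], weight: [hidden_size]
    x_fp32 = x.float()
    variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)   # per-token mean of squares
    normed = x_fp32 * torch.rsqrt(variance + eps)         # x / sqrt(mean(x^2) + eps)
    return normed.to(x.dtype) * weight                    # scale by the learned weight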
diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/reduction_utils.cuh b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/reduction_utils.cuh
new file mode 100644
index 000000000000..2d47c5222084
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rmsnorm/reduction_utils.cuh
@@ -0,0 +1,50 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+
+template<typename T>
+__inline__ __device__ T warpReduceSum(T val) {
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1)
+    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
+  return val;
+}
+
+/* Calculate the sum of all elements in a block */
+template<typename T>
+__inline__ __device__ T blockReduceSum(T val) {
+  static __shared__ T shared[32];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+
+  val = warpReduceSum<T>(val);
+
+  if (lane == 0)
+    shared[wid] = val;
+
+  __syncthreads();
+
+  // Use blockDim.x / 32.f rather than blockDim.x >> 5 so that a trailing
+  // partial warp is still counted when blockDim.x is not a multiple of 32.
+  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum<T>(val);
+  return val;
+}
+
diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp
new file mode 100644
index 000000000000..21214d764d36
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding.cpp
@@ -0,0 +1,19 @@
+/* Copyright 2021 The Colossal-AI Team
+   Copyright (c) 2023, The vLLM team.
+   This file is adapted from the vLLM team: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding.cpp
+*/
+#include <torch/extension.h>
+
+void rotary_embedding_neox(
+  torch::Tensor& positions,
+  torch::Tensor& query,
+  torch::Tensor& key,
+  int head_size,
+  torch::Tensor& cos_sin_cache);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "rotary_embedding_neox",
+    &rotary_embedding_neox,
+    "Apply GPT-NeoX style rotary embedding to query and key");
+}
diff --git a/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu
new file mode 100644
index 000000000000..3b0767e23938
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu
@@ -0,0 +1,90 @@
+/* Copyright 2021 The Colossal-AI Team
+   Copyright (c) 2023, The vLLM team.
+   This file is adapted from the vLLM team: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
+*/
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+template<typename scalar_t>
+__global__ void rotary_embedding_neox_kernel(
+  const int64_t* __restrict__ positions,        // [num_tokens]
+  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
+  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
+  const int rot_dim,
+  const int stride,
+  const int num_heads,
+  const int num_kv_heads,
+  const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+
+  const int embed_dim = rot_dim / 2;
+  const int nq = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * stride + head_idx * head_size;
+
+    const int rot_offset = i % embed_dim;
+    const int x_index = rot_offset;
+    const int y_index = embed_dim + rot_offset;
+
+    const int out_x = token_idx * stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * stride + head_idx * head_size + y_index;
+
+    const scalar_t cos = __ldg(cache_ptr + x_index);
+    const scalar_t sin = __ldg(cache_ptr + y_index);
+
+    const scalar_t q_x = query[token_head + x_index];
+    const scalar_t q_y = query[token_head + y_index];
+    query[out_x] = q_x * cos - q_y * sin;
+    query[out_y] = q_y * cos + q_x * sin;
+
+    if (head_idx < num_kv_heads) {
+      const scalar_t k_x = key[token_head + x_index];
+      const scalar_t k_y = key[token_head + y_index];
+      key[out_x] = k_x * cos - k_y * sin;
+      key[out_y] = k_y * cos + k_x * sin;
+    }
+  }
+}
+
+
+void rotary_embedding_neox(
+  torch::Tensor& positions,        // [num_tokens]
+  torch::Tensor& query,            // [num_tokens, num_heads * head_size]
+  torch::Tensor& key,              // [num_tokens, num_kv_heads * head_size]
+  int head_size,
+  torch::Tensor& cos_sin_cache)    // [max_position, rot_dim]
+{
+  int num_tokens = query.size(0);
+  int rot_dim = cos_sin_cache.size(1);
+  int num_heads = query.size(1) / head_size;
+  int num_kv_heads = key.size(1) / head_size;
+  int stride = query.stride(0);
+  TORCH_CHECK(stride == key.stride(0));
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
+    query.scalar_type(),
+    "rotary_embedding_neox",
+    [&] {
+      rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        positions.data_ptr<int64_t>(),
+        query.data_ptr<scalar_t>(),
+        key.data_ptr<scalar_t>(),
+        cos_sin_cache.data_ptr<scalar_t>(),
+        rot_dim,
+        stride,
+        num_heads,
+        num_kv_heads,
+        head_size);
+    });
+}
\ No newline at end of file
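The kernel rotates query and key in place, GPT-NeoX style: for each token it reads one row of cos_sin_cache (cos values in the first rot_dim/2 entries, sin values in the second half) and updates the index pairs (i, i + rot_dim/2) of every head. A rough PyTorch equivalent of that per-token update, for illustration only (the helper name is not part of this diff):

import torch

def rotary_embedding_neox_reference(positions, query, key, head_size, cos_sin_cache):
    # positions: [num_tokens]; query/key: [num_tokens, num_heads * head_size] (contiguous)
    # cos_sin_cache: [max_position, rot_dim], cos in the first half of each row, sin in the second
    num_tokens = query.shape[0]
    rot_dim = cos_sin_cache.shape[1]
    embed_dim = rot_dim // 2
    q = query.view(num_tokens, -1, head_size)
    k = key.view(num_tokens, -1, head_size)
    cos = cos_sin_cache[positions, :embed_dim].unsqueeze(1)   # [num_tokens, 1, embed_dim]
    sin = cos_sin_cache[positions, embed_dim:].unsqueeze(1)
    for t in (q, k):                                          # both tensors are updated in place
        x = t[..., :embed_dim].clone()
        y = t[..., embed_dim:rot_dim].clone()
        t[..., :embed_dim] = x * cos - y * sin
        t[..., embed_dim:rot_dim] = y * cos + x * sin
    return query, key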
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index f1d2998bbee4..ae3d8e207881 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -7,7 +7,7 @@
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
-from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
+from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -391,9 +391,17 @@ def llama_for_sequence_classification_forward(
 
 def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-
     from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 
+    try:
+        from vllm import pos_encoding_ops
+        rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
+        HAS_VLLM_KERNERL = True
+    except:
+        print("falling back to the original huggingface rotary_embedding_neox implementation")
+        print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
+        HAS_VLLM_KERNERL = False
+
     def forward(
         self: LlamaAttention,
         hidden_states: torch.Tensor,
@@ -415,7 +423,12 @@ def forward(
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if HAS_VLLM_KERNERL:
+            cos_sin_cache = torch.cat((cos, sin), dim=-1)
+            rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
             # reuse k, v, self_attention
@@ -450,3 +463,28 @@ def forward(
         return attn_output, None, past_key_value
 
     return forward
+
+
+def get_llama_vllm_rmsnorm_forward():
+    try:
+        from vllm import layernorm_ops
+        rms_norm = layernorm_ops.rms_norm
+        HAS_VLLM_KERNERL = True
+    except:
+        print("please install the vllm kernels to use the fused rms_norm")
+        print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
+        HAS_VLLM_KERNERL = False
+
+    def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+        x = hidden_states
+        out = torch.empty_like(x)
+        rms_norm(
+            out,
+            x,
+            self.weight.data,
+            self.variance_epsilon,
+        )
+
+        return out
+
+    return _vllm_rmsnorm_forward
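Both helpers return a drop-in forward function that is meant to be bound onto the corresponding HuggingFace module by a Shardformer policy. A rough sketch of that wiring using plain method patching (this is generic Python for illustration, not the exact policy code in this PR):

import types
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm

from colossalai.shardformer.modeling.llama import (
    get_llama_flash_attention_forward,
    get_llama_vllm_rmsnorm_forward,
)

def patch_llama_with_inference_kernels(model: nn.Module) -> nn.Module:
    # Illustrative only: bind the kernel-backed forwards onto each module instance.
    attn_forward = get_llama_flash_attention_forward()
    rmsnorm_forward = get_llama_vllm_rmsnorm_forward()
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            module.forward = types.MethodType(attn_forward, module)
        elif isinstance(module, LlamaRMSNorm):
            module.forward = types.MethodType(rmsnorm_forward, module)
    return model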
diff --git a/op_builder/__init__.py b/op_builder/__init__.py
index 5ae7223b8c69..85d26a258d23 100644
--- a/op_builder/__init__.py
+++ b/op_builder/__init__.py
@@ -5,6 +5,8 @@
 from .multi_head_attn import MultiHeadAttnBuilder
 from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
 from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder
+from .rmsnorm import RMSNORMBuilder
+from .rotary_embedding import ROTARYEMBEDDINGBuilder
 
 ALL_OPS = {
     'cpu_adam': CPUAdamBuilder,
@@ -14,10 +16,14 @@
     'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder,
     'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder,
     'layernorm': LayerNormBuilder,
+    'rmsnorm': RMSNORMBuilder,
+    'rotary_embedding': ROTARYEMBEDDINGBuilder,
 }
 
 __all__ = [
     'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder',
     'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder',
-    'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder'
+    'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder',
+    'RMSNORMBuilder',
+    'ROTARYEMBEDDINGBuilder',
 ]
diff --git a/op_builder/rmsnorm.py b/op_builder/rmsnorm.py
new file mode 100644
index 000000000000..27d1d879d6d6
--- /dev/null
+++ b/op_builder/rmsnorm.py
@@ -0,0 +1,39 @@
+import os
+
+from .builder import Builder
+from .utils import append_nvcc_threads, get_cuda_cc_flag
+
+
+class RMSNORMBuilder(Builder):
+
+    NAME = "rmsnorm"
+    PREBUILT_IMPORT_PATH = "colossalai._C.rmsnorm"
+
+    def __init__(self):
+        super().__init__(name=RMSNORMBuilder.NAME,
+                         prebuilt_import_path=RMSNORMBuilder.PREBUILT_IMPORT_PATH)
+
+    def include_dirs(self):
+        ret = [self.csrc_abs_path("attention_infer_kernels/rmsnorm"), self.get_cuda_home_include()]
+        return ret
+
+    def sources_files(self):
+        ret = [
+            self.csrc_abs_path(fname) for fname in [
+                'attention_infer_kernels/rmsnorm/layernorm_kernels.cu',
+                'attention_infer_kernels/rmsnorm/layernorm.cpp'
+            ]
+        ]
+        return ret
+
+    def cxx_flags(self):
+        return ['-O3'] + self.version_dependent_macros
+
+    def nvcc_flags(self):
+        extra_cuda_flags = [
+            '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
+            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
+        ]
+        extra_cuda_flags.extend(get_cuda_cc_flag())
+        ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
+        return append_nvcc_threads(ret)
diff --git a/op_builder/rotary_embedding.py b/op_builder/rotary_embedding.py
new file mode 100644
index 000000000000..3633becbc054
--- /dev/null
+++ b/op_builder/rotary_embedding.py
@@ -0,0 +1,68 @@
+import os
+from packaging.version import parse, Version
+from setuptools import setup, find_packages
+import subprocess
+
+from .builder import Builder
+from .utils import append_nvcc_threads, get_cuda_cc_flag
+
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+
+
+def add_cc_flags():
+
+    def get_cuda_bare_metal_version(cuda_dir):
+        raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+        output = raw_output.split()
+        release_idx = output.index("release") + 1
+        bare_metal_version = parse(output[release_idx].split(",")[0])
+
+        return raw_output, bare_metal_version
+
+    # Check whether CUDA 11 is installed for compute capability 8.0.
+    cc_flag = []
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    print(bare_metal_version)
+    if bare_metal_version < Version("11.0"):
+        raise RuntimeError("the rotary embedding kernel is only supported on CUDA 11 and above")
+
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_80,code=sm_80")
+    if bare_metal_version >= Version("11.8"):
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_90,code=sm_90")
+
+    return cc_flag
+
+
+class ROTARYEMBEDDINGBuilder(Builder):
+
+    NAME = "rotary_embedding"
+    PREBUILT_IMPORT_PATH = "colossalai._C.rotary_embedding"
+
+    def __init__(self):
+        super().__init__(name=ROTARYEMBEDDINGBuilder.NAME,
+                         prebuilt_import_path=ROTARYEMBEDDINGBuilder.PREBUILT_IMPORT_PATH)
+
+    def include_dirs(self):
+        ret = [self.csrc_abs_path("attention_infer_kernels/rotary_embedding"), self.get_cuda_home_include()]
+        return ret
+
+    def sources_files(self):
+        ret = [
+            self.csrc_abs_path(fname) for fname in [
+                'attention_infer_kernels/rotary_embedding/pos_encoding_kernels.cu',
+                'attention_infer_kernels/rotary_embedding/pos_encoding.cpp'
+            ]
+        ]
+        return ret
+
+    def cxx_flags(self):
+        return ['-O3'] + self.version_dependent_macros
+
+    def nvcc_flags(self):
+        extra_cuda_flags = [
+            '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
+            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
+        ]
+        extra_cuda_flags.extend(add_cc_flags())
+        ret = ['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags
+        return append_nvcc_threads(ret)
diff --git a/setup.py b/setup.py
index 5d8f831218d9..4e530909fea0 100644
--- a/setup.py
+++ b/setup.py
@@ -138,6 +138,7 @@ def get_version() -> str:
     op_name_list = ', '.join(op_names)
     print(f"[extension] loaded builders for {op_name_list}")
 
+
 # always put not nightly branch as the if branch
 # otherwise github will treat colossalai-nightly as the project name
 # and it will mess up with the dependency graph insights
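With the two builders registered in op_builder, the kernels can be compiled and loaded through the standard op-builder path, which is what the tests below rely on. A minimal usage sketch (tensor shapes and epsilon are illustrative):

import torch
from colossalai.kernel.op_builder import RMSNORMBuilder, ROTARYEMBEDDINGBuilder

# JIT-build (or load the prebuilt) extensions via the op builders.
rmsnorm_ops = RMSNORMBuilder().load()
rotary_ops = ROTARYEMBEDDINGBuilder().load()

hidden = torch.randn(8, 4096, dtype=torch.float16, device="cuda")   # [num_tokens, hidden_size]
weight = torch.ones(4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(hidden)
rmsnorm_ops.rms_norm(out, hidden, weight, 1e-6)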
diff --git a/tests/test_kernels/cuda/test_rmsnorm.py b/tests/test_kernels/cuda/test_rmsnorm.py
new file mode 100644
index 000000000000..a09f96b56816
--- /dev/null
+++ b/tests/test_kernels/cuda/test_rmsnorm.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+import os
+import pytest
+import numpy as np
+from packaging import version
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+try:
+    from colossalai.kernel.op_builder import RMSNORMBuilder
+    rmsnorm = RMSNORMBuilder().load()
+    rms_norm = rmsnorm.rms_norm
+    HAS_INFER_CUDA = True
+except:
+    HAS_INFER_CUDA = False
+    print("please install the CUDA inference kernels to run this test")
+
+
+class LlamaRMSNorm(nn.Module):
+
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon):
+    x = hidden_states
+    out = torch.empty_like(x)
+    rms_norm(
+        out,
+        x,
+        weight,
+        variance_epsilon,
+    )
+    return out
+
+
+@pytest.mark.skipif(not HAS_INFER_CUDA, reason="You need to install the llama CUDA inference kernels to run this test")
+def test_rmsnorm():
+    data = torch.randn((1024, 64), dtype=torch.float16, device="cuda")
+    hg_rms = LlamaRMSNorm(64)
+    hg_rms = hg_rms.half().cuda()
+    out_torch = hg_rms(data)
+    out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon)
+
+    check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5)
+    assert check, "cuda rmsnorm forward does not match the torch rmsnorm forward"
+
+
+if __name__ == "__main__":
+    test_rmsnorm()
\ No newline at end of file
diff --git a/tests/test_kernels/cuda/test_rotary_embedding.py b/tests/test_kernels/cuda/test_rotary_embedding.py
new file mode 100644
index 000000000000..e48d252d0add
--- /dev/null
+++ b/tests/test_kernels/cuda/test_rotary_embedding.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+import pytest
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half
+
+try:
+    from colossalai.kernel.op_builder import ROTARYEMBEDDINGBuilder
+    rotary_embedding = ROTARYEMBEDDINGBuilder().load()
+    rotary_embedding_neox = rotary_embedding.rotary_embedding_neox
+    HAS_INFER_CUDA = True
+except:
+    HAS_INFER_CUDA = False
+    print("the cuda inference kernels for llama attention are not installed")
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class RefRotaryEmbeddingNeox(nn.Module):
+    """Reference implementation of the GPT-NeoX style rotary embedding."""
+
+    def __init__(
+        self,
+        dim: int,
+        max_position_embeddings: int = 2048,
+        base: int = 10000,
+    ) -> None:
+        super().__init__()
+        self.rotary_dim = dim
+        self.max_position_embeddings = max_position_embeddings
+
+        # Create cos and sin embeddings.
+        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
+        t = torch.arange(max_position_embeddings).float()
+        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
+        emb = torch.cat((freqs, freqs), dim=-1)
+        cos = emb.cos().to(dtype=inv_freq.dtype)
+        sin = emb.sin().to(dtype=inv_freq.dtype)
+        self.register_buffer("cos_cached", cos, persistent=False)
+        self.register_buffer("sin_cached", sin, persistent=False)
+
+    def forward(
+        self,
+        positions: torch.Tensor,    # [num_tokens]
+        query: torch.Tensor,        # [num_tokens, num_heads, head_size]
+        key: torch.Tensor,          # [num_tokens, num_heads, head_size]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+
+        query_rot = query_rot.transpose(0, 1)
+        key_rot = key_rot.transpose(0, 1)
+        cos = F.embedding(positions, self.cos_cached)
+        sin = F.embedding(positions, self.sin_cached)
+        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+        query_rot = query_rot.transpose(0, 1).contiguous()
+        key_rot = key_rot.transpose(0, 1).contiguous()
+
+        query = torch.cat((query_rot, query_pass), dim=-1)
+        key = torch.cat((key_rot, key_pass), dim=-1)
+
+        # Output query/key shape: [num_tokens, num_heads, head_size]
+        return query, key
+
+
+def run_rotary_embedding_neox(
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    max_position: int,
+    rotary_dim: int,
+    dtype: torch.dtype,
+    base: int = 10000,
+) -> None:
+    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
+    query = torch.randn(num_tokens,
+                        num_heads * head_size,
+                        dtype=dtype,
+                        device='cuda')
+    key = torch.randn(num_tokens,
+                      num_heads * head_size,
+                      dtype=dtype,
+                      device='cuda')
+
+    # Create the rotary embedding cache: [max_position, rot_dim], with cos in
+    # the first half of each row and sin in the second half.
+    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
+    t = torch.arange(max_position).float()
+    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+    cos = freqs.cos()
+    sin = freqs.sin()
+    cos_sin_cache = torch.cat((cos, sin), dim=-1)
+    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
+
+    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
+    out_query = query.clone()
+    out_key = key.clone()
+    rotary_embedding_neox(
+        positions,
+        out_query,
+        out_key,
+        head_size,
+        cos_sin_cache,
+    )
+
+    # Run the reference implementation.
+    ref_rotary_embedding = RefRotaryEmbeddingNeox(
+        dim=rotary_dim,
+        max_position_embeddings=max_position,
+        base=base,
+    ).to(dtype=dtype, device='cuda')
+    ref_query, ref_key = ref_rotary_embedding(
+        positions,
+        query.view(num_tokens, num_heads, head_size),
+        key.view(num_tokens, num_heads, head_size),
+    )
+    ref_query = ref_query.view(num_tokens, num_heads * head_size)
+    ref_key = ref_key.view(num_tokens, num_heads * head_size)
+
+    # Compare the results.
+    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
+    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
+
+
+@pytest.mark.skipif(not HAS_INFER_CUDA, reason="You need to install the llama CUDA inference kernels to run this test")
+def test_rotary_embedding():
+    run_rotary_embedding_neox(
+        num_tokens=1024,
+        num_heads=8,
+        head_size=64,
+        max_position=8192,
+        rotary_dim=64,
+        dtype=torch.float16,
+    )
+
+
+if __name__ == "__main__":
+    test_rotary_embedding()
\ No newline at end of file
diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/triton/test_self_attention.py
similarity index 100%
rename from tests/test_kernels/test_self_attention.py
rename to tests/test_kernels/triton/test_self_attention.py
diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/triton/test_softmax.py
similarity index 100%
rename from tests/test_kernels/test_softmax.py
rename to tests/test_kernels/triton/test_softmax.py