32 commits
064a8f2
adding kernels
tiandiao123 Aug 10, 2023
68a4735
update cutlass
tiandiao123 Aug 10, 2023
c70dcbe
update
tiandiao123 Aug 10, 2023
a8938ba
adding kernels
tiandiao123 Aug 10, 2023
b9b4396
delete useless files
tiandiao123 Aug 10, 2023
7a236e3
clean codes
tiandiao123 Aug 10, 2023
54ac1e1
added cuda test
tiandiao123 Aug 10, 2023
e48a143
added tests
tiandiao123 Aug 14, 2023
49bb1e3
added tests
tiandiao123 Aug 14, 2023
0b3cffa
add comments
tiandiao123 Aug 14, 2023
c171f43
refactoring
tiandiao123 Aug 15, 2023
0e0594e
fix tests
tiandiao123 Aug 15, 2023
9393dd1
change flash-attention as third-party directly
tiandiao123 Aug 15, 2023
c524925
remove cutlass
tiandiao123 Aug 15, 2023
8f51549
delete cutlass
tiandiao123 Aug 15, 2023
9efe2e9
cleaned codes
tiandiao123 Aug 15, 2023
caefcba
add
tiandiao123 Aug 15, 2023
ad3fa46
add norm
tiandiao123 Aug 15, 2023
7ba2d61
delete useless files
tiandiao123 Aug 15, 2023
03e4149
change setup
tiandiao123 Aug 16, 2023
746c1c9
change intp build class
tiandiao123 Aug 16, 2023
b1a2c19
added info
tiandiao123 Aug 16, 2023
d09a81a
added license
tiandiao123 Aug 16, 2023
0b0b28c
added license
tiandiao123 Aug 16, 2023
8dd322f
added license
tiandiao123 Aug 16, 2023
1d2294e
change name
tiandiao123 Aug 16, 2023
1b53af9
change flash version
tiandiao123 Aug 16, 2023
9313374
change req
tiandiao123 Aug 16, 2023
90859c6
added it
tiandiao123 Aug 16, 2023
559f2d5
merged
tiandiao123 Aug 17, 2023
b921fa2
modify req
tiandiao123 Aug 17, 2023
e4341de
added _vllm_rmsnorm_forward and fix llama forward
tiandiao123 Aug 17, 2023
18 changes: 18 additions & 0 deletions LICENSE
@@ -396,3 +396,21 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

---------------- LICENSE FOR VLLM TEAM ----------------

from VLLM TEAM:

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/vllm-project/vllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@@ -0,0 +1,18 @@
/* Copyright 2021 The Colossal-AI Team
   Copyright (c) 2023, The vLLM team.
   This file is adapted from the vLLM team: https://github.com/vllm-project/vllm/blob/main/csrc/layernorm.cpp
*/
#include <torch/extension.h>

// Fused RMSNorm: writes the normalized, weight-scaled input into `out`.
void rms_norm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weight,
  float epsilon);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "rms_norm",
    &rms_norm,
    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
}
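A minimal sketch of driving this binding from Python once the extension is compiled. The loading path via RMSNORMBuilder (added in op_builder below) and its load() method are assumptions mirroring the existing ColossalAI op builders, not part of this diff:

import torch

from op_builder import RMSNORMBuilder

# load() JIT-compiles and imports the extension on first use
# (assumed builder API, mirroring the other ColossalAI op builders).
rmsnorm_ext = RMSNORMBuilder().load()

x = torch.randn(4, 512, dtype=torch.float16, device="cuda")
weight = torch.ones(512, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)

# rms_norm writes its result into `out` in place.
rmsnorm_ext.rms_norm(out, x, weight, 1e-6)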
@@ -0,0 +1,63 @@
/* Copyright 2021 The Colossal-AI Team
   Copyright (c) 2023, The vLLM team.
   This file is adapted from the vLLM team: https://github.com/vllm-project/vllm/blob/main/csrc/layernorm_kernels.cu
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "reduction_utils.cuh"

template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,           // [num_tokens, hidden_size]
  const scalar_t* __restrict__ input,   // [num_tokens, hidden_size]
  const scalar_t* __restrict__ weight,  // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Each block handles one token; accumulate the sum of squares in float.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Normalize and scale by the per-channel weight.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

void rms_norm(
  torch::Tensor& out,     // [num_tokens, hidden_size]
  torch::Tensor& input,   // [num_tokens, hidden_size]
  torch::Tensor& weight,  // [hidden_size]
  float epsilon) {
  int num_tokens = input.size(0);
  int hidden_size = input.size(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
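For testing, the kernel's math reduces to the following pure-PyTorch reference (a sketch; note the float32 accumulation and the cast back to the input dtype before the weight multiply, both of which mirror the kernel above):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Mean of squares accumulated in float32, as in the kernel.
    x_f = x.float()
    inv_rms = torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    # Cast back to the input dtype before scaling by the per-channel weight,
    # matching `((scalar_t) (x * s_variance)) * weight[idx]`.
    return (x_f * inv_rms).to(x.dtype) * weight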
@@ -0,0 +1,50 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once


template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
  return val;
}

/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  if (lane == 0)
    shared[wid] = val;

  __syncthreads();

  // Use blockDim.x / 32.f (rather than blockDim.x >> 5) so that a block
  // whose size is not a multiple of 32 still reads the partial sum of its
  // trailing, incomplete warp.
  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);
  return val;
}
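The reduction is two-staged: each warp of 32 threads reduces its own values with shuffles, lane 0 of each warp publishes its partial sum to shared memory, and the first warp then reduces the partials. A plain-Python sketch of the same dataflow:

def block_reduce_sum(values, warp_size=32):
    # Stage 1: each warp reduces its own (up to 32) values.
    partials = [sum(values[i:i + warp_size]) for i in range(0, len(values), warp_size)]
    # Stage 2: the partials are reduced. The `threadIdx.x < blockDim.x / 32.f`
    # guard above plays the role of len(partials) here: a trailing,
    # incomplete warp still contributes its partial sum.
    return sum(partials)

assert block_reduce_sum(list(range(100))) == sum(range(100))  # works even though 100 % 32 != 0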

@@ -0,0 +1,19 @@
/* Copyright 2021 The Colossal-AI Team
   Copyright (c) 2023, The vLLM team.
   This file is adapted from the vLLM team: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding.cpp
*/
#include <torch/extension.h>

// In-place GPT-NeoX rotary embedding over the query and key projections.
void rotary_embedding_neox(
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
  torch::Tensor& cos_sin_cache);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "rotary_embedding_neox",
    &rotary_embedding_neox,
    "Apply GPT-NeoX style rotary embedding to query and key");
}
@@ -0,0 +1,90 @@
/* Copyright 2021 The Colossal-AI Team
   Copyright (c) 2023, The vLLM team.
   This file is adapted from the vLLM team: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
  const int64_t* __restrict__ positions,       // [num_tokens]
  scalar_t* __restrict__ query,                // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                  // [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  const int embed_dim = rot_dim / 2;
  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int token_head = token_idx * stride + head_idx * head_size;

    // The first half of the rotary dims is the x component, the second the y.
    const int rot_offset = i % embed_dim;
    const int x_index = rot_offset;
    const int y_index = embed_dim + rot_offset;

    const int out_x = token_idx * stride + head_idx * head_size + x_index;
    const int out_y = token_idx * stride + head_idx * head_size + y_index;

    const scalar_t cos = __ldg(cache_ptr + x_index);
    const scalar_t sin = __ldg(cache_ptr + y_index);

    const scalar_t q_x = query[token_head + x_index];
    const scalar_t q_y = query[token_head + y_index];
    query[out_x] = q_x * cos - q_y * sin;
    query[out_y] = q_y * cos + q_x * sin;

    // Keys may have fewer heads than queries (grouped-/multi-query attention).
    if (head_idx < num_kv_heads) {
      const scalar_t k_x = key[token_head + x_index];
      const scalar_t k_y = key[token_head + y_index];
      key[out_x] = k_x * cos - k_y * sin;
      key[out_y] = k_y * cos + k_x * sin;
    }
  }
}

void rotary_embedding_neox(
  torch::Tensor& positions,      // [num_tokens]
  torch::Tensor& query,          // [num_tokens, num_heads * head_size]
  torch::Tensor& key,            // [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache)  // [max_position, rot_dim]
{
  int num_tokens = query.size(0);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(1) / head_size;
  int num_kv_heads = key.size(1) / head_size;
  int stride = query.stride(0);
  TORCH_CHECK(stride == key.stride(0));

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    query.scalar_type(),
    "rotary_embedding_neox",
    [&] {
      rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
        positions.data_ptr<int64_t>(),
        query.data_ptr<scalar_t>(),
        key.data_ptr<scalar_t>(),
        cos_sin_cache.data_ptr<scalar_t>(),
        rot_dim,
        stride,
        num_heads,
        num_kv_heads,
        head_size);
    });
}
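An equivalent (out-of-place) PyTorch reference for the kernel, useful for testing; this sketch assumes contiguous query/key, positions of shape [num_tokens], and a cache laid out as [cos | sin] per position, as the kernel does:

import torch

def rotary_embedding_neox_reference(positions, query, key, head_size, cos_sin_cache):
    # query: [num_tokens, num_heads * head_size], key: [num_tokens, num_kv_heads * head_size]
    num_tokens = query.shape[0]
    rot_dim = cos_sin_cache.shape[1]
    embed_dim = rot_dim // 2
    cache = cos_sin_cache[positions]           # [num_tokens, rot_dim]
    cos = cache[:, :embed_dim].unsqueeze(1)    # [num_tokens, 1, embed_dim]
    sin = cache[:, embed_dim:].unsqueeze(1)

    def rotate(t):
        t = t.clone().view(num_tokens, -1, head_size)
        # Snapshot x/y before writing, to avoid aliasing the in-place updates.
        x = t[..., :embed_dim].clone()
        y = t[..., embed_dim:rot_dim].clone()
        t[..., :embed_dim] = x * cos - y * sin
        t[..., embed_dim:rot_dim] = y * cos + x * sin
        return t.view(num_tokens, -1)

    # Broadcasting over the head dimension also covers keys with fewer heads.
    return rotate(query), rotate(key)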
44 changes: 41 additions & 3 deletions colossalai/shardformer/modeling/llama.py
@@ -7,7 +7,7 @@
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm
from transformers.utils import logging

from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -391,9 +391,17 @@ def llama_for_sequence_classification_forward(
def get_llama_flash_attention_forward():

    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

    try:
        from vllm import pos_encoding_ops
        rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
        HAS_VLLM_KERNEL = True
    except ImportError:
        print("falling back to the original HuggingFace rotary_embedding_neox")
        print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
        HAS_VLLM_KERNEL = False

    def forward(
        self: LlamaAttention,
        hidden_states: torch.Tensor,
@@ -415,7 +423,12 @@ def forward(
            kv_seq_len += past_key_value[0].shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        if HAS_VLLM_KERNEL:
            # Fused in-place rotary embedding from the vllm extension.
            cos_sin_cache = torch.cat((cos, sin), dim=-1)
            rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache)
        else:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
@@ -450,3 +463,28 @@
        return attn_output, None, past_key_value

    return forward


def get_llama_vllm_rmsnorm_forward():
    try:
        from vllm import layernorm_ops
        rms_norm = layernorm_ops.rms_norm
        HAS_VLLM_KERNEL = True
    except ImportError:
        print("please install vllm kernels to use the fused rmsnorm")
        print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
        HAS_VLLM_KERNEL = False

    if not HAS_VLLM_KERNEL:
        # Without the kernel, `rms_norm` is undefined; signal the caller to
        # keep the original LlamaRMSNorm.forward.
        return None

    def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
        x = hidden_states
        out = torch.empty_like(x)
        # The extension writes the normalized result into `out` in place.
        rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )

        return out

    return _vllm_rmsnorm_forward
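A sketch of wiring the returned forward in. Direct monkey-patching is shown purely for illustration; in Shardformer this substitution would normally be registered through a model policy:

from transformers.models.llama.modeling_llama import LlamaRMSNorm

rmsnorm_forward = get_llama_vllm_rmsnorm_forward()
if rmsnorm_forward is not None:
    # Every LlamaRMSNorm module now dispatches to the fused vllm kernel.
    LlamaRMSNorm.forward = rmsnorm_forward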
8 changes: 7 additions & 1 deletion op_builder/__init__.py
@@ -5,6 +5,8 @@
from .multi_head_attn import MultiHeadAttnBuilder
from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder
from .rmsnorm import RMSNORMBuilder
from .rotary_embedding import ROTARYEMBEDDINGBuilder

ALL_OPS = {
    'cpu_adam': CPUAdamBuilder,
@@ -14,10 +16,14 @@
    'scaled_masked_softmax': ScaledMaskedSoftmaxBuilder,
    'scaled_upper_triangle_masked_softmax': ScaledUpperTrainglemaskedSoftmaxBuilder,
    'layernorm': LayerNormBuilder,
    'rmsnorm': RMSNORMBuilder,
    'rotary_embedding': ROTARYEMBEDDINGBuilder,
}

__all__ = [
    'ALL_OPS', 'CPUAdamBuilder', 'FusedOptimBuilder', 'MultiHeadAttnBuilder', 'ScaledMaskedSoftmaxBuilder',
    'ScaledUpperTrainglemaskedSoftmaxBuilder', 'MOEBuilder', 'MultiTensorSGDBuilder', 'MultiTensorAdamBuilder',
    'MultiTensorLambBuilder', 'MultiTensorScaleBuilder', 'MultiTensorL2NormBuilder',
    'RMSNORMBuilder',
    'ROTARYEMBEDDINGBuilder',
]
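With the registry entries above, the new kernels can be built like the existing ops; the load() call is an assumption carried over from the other ColossalAI builders:

from op_builder import ALL_OPS

# JIT-build (or load a cached build of) the new extensions.
rmsnorm_ext = ALL_OPS['rmsnorm']().load()
rotary_ext = ALL_OPS['rotary_embedding']().load()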