73 changes: 56 additions & 17 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -8,7 +8,6 @@
LlamaDecoderLayer,
LlamaModel,
LlamaRMSNorm,
apply_rotary_pos_emb,
)

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
@@ -29,6 +28,29 @@
)
HAS_VLLM_KERNERL = False

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]

q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed

def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return


class LlamaInferenceForwards:
"""
@@ -251,8 +273,9 @@ def llama_flash_attn_kvcache_forward(

query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
key_states_transposed = key_states.transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)


# NOTE might want to revise
# need some way to record the length of past key values cache
@@ -261,26 +284,42 @@
infer_state.cache_manager.past_key_values_length += q_len # seq_len

if HAS_VLLM_KERNERL:
# NOTE: fix rotary embedding precision problem
cos, sin = infer_state.position_cos, infer_state.position_sin

# value_states_transposed = value_states.transpose(1, 2)

# cos, sin = self.rotary_emb(value_states_transposed,
# seq_len=infer_state.cache_manager.past_key_values_length)

cos_sin_cache = torch.cat((cos, sin), dim=-1)
rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
key_states = key_states_transposed.transpose(1, 2)

key_states = key_states.view(-1, self.num_heads * self.head_dim)
query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads * self.head_dim)
rotary_embedding_neox(position_ids.squeeze(1), query_states, key_states, self.head_dim, cos_sin_cache)


query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)

else:
# NOTE: there are some issues with the original rotary_embedding_neox from huggingface

value_states_transposed = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states_transposed,
seq_len=infer_state.cache_manager.past_key_values_length)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids)
key_states = key_states_transposed.transpose(1, 2)

def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
return

key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
seq_len=infer_state.cache_manager.past_key_values_length)

rotary_positions_ids = position_ids
idx = position_ids.shape[0] - 1
if idx >= 1:
rotary_positions_ids = [[idx]]
query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, rotary_positions_ids)
query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
value_states = value_states.reshape(-1, self.num_heads, self.head_dim)

if infer_state.is_context_stage:
# first token generation
@@ -330,7 +369,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
def get_llama_vllm_rmsnorm_forward():

if HAS_VLLM_KERNERL:

def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
x = hidden_states
out = torch.empty_like(x)
@@ -346,3 +384,4 @@ def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return _vllm_rmsnorm_forward
else:
return None

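Note: the rotate_half / apply_rotary_pos_emb fallback added above can be exercised on dummy tensors. The sketch below uses made-up shapes and a hand-rolled cos/sin cache purely for illustration; it is not the PR's actual call site, where the cache comes from self.rotary_emb.

import torch

bs, n_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)

# build a cos/sin cache shaped [1, 1, seq_len, head_dim], matching what HF's rotary embedding returns
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]

position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, seq_len)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == (bs, n_heads, seq_len, head_dim)
assert k_rot.shape == (bs, n_heads, seq_len, head_dim)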
53 changes: 37 additions & 16 deletions colossalai/inference/tensor_parallel/policies/llama.py
@@ -1,12 +1,33 @@
from functools import partial
import torch
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaModel,
LlamaRMSNorm
)

from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

try:
from colossalai.kernel.triton.rms_norm import rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
HAS_TRITON_RMSNORM = False


def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

return _triton_rmsnorm_forward
else:
return None

class LlamaModelInferPolicy(LlamaForCausalLMPolicy):

def __init__(self) -> None:
@@ -16,12 +37,6 @@ def module_policy(self):
policy = super().module_policy()
self.shard_config._infer()

# example for replace layer or decoder
# if self.shard_config.enable_flash_attention:
# policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
# 'forward': get_llama_flash_attention_forward(),
# })

infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
@@ -38,12 +53,18 @@ def module_policy(self):
policy=policy,
target_key=LlamaAttention)

# NOTE: adding rms_norm caused precision issue, fix @tiandiao123
# infer_forward = get_llama_vllm_rmsnorm_forward()
# if infer_forward is not None:
# method_replacement = {'forward': partial(infer_forward)}
# self.append_or_create_method_replacement(description=method_replacement,
# policy=policy,
# target_key=LlamaRMSNorm)
infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()
else:
# NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
infer_forward = get_llama_vllm_rmsnorm_forward()

if infer_forward is not None:
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaRMSNorm)

return policy

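In effect, the policy above prefers the Triton RMSNorm kernel when Triton is importable and falls back to the vLLM kernel otherwise. A rough, hypothetical illustration of what the method replacement amounts to (ShardFormer performs the real binding per module; this is not its actual mechanism):

from transformers.models.llama.modeling_llama import LlamaRMSNorm

# pick the fused forward: Triton first, then the vLLM CUDA kernel, else keep the default
infer_forward = get_triton_rmsnorm_forward() if HAS_TRITON_RMSNORM else get_llama_vllm_rmsnorm_forward()
if infer_forward is not None:
    LlamaRMSNorm.forward = infer_forward  # every RMSNorm in the model would then call the fused kernel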
1 change: 1 addition & 0 deletions colossalai/kernel/triton/__init__.py
@@ -1,3 +1,4 @@
from .context_attention import llama_context_attn_fwd, bloom_context_attn_fwd
from .softmax import softmax
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .rms_norm import rmsnorm_forward
72 changes: 72 additions & 0 deletions colossalai/kernel/triton/rms_norm.py
@@ -0,0 +1,72 @@
import torch

try:
import triton
import triton.language as tl
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")


if HAS_TRITON:
'''
this kernel function is modified from
https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
'''
@triton.jit
def _rms_norm_fwd_fused(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
Y += row * stride
X += row * stride
# Compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
_var += x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# Normalize and apply linear transformation
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
w = tl.load(W + cols, mask=mask).to(tl.float32)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
x_hat = x * rstd
y = x_hat * w
# Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask)


def rmsnorm_forward(x, weight, eps):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
x_arg = x.view(-1, x.shape[-1])
M, N = x_arg.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
# print("BLOCK_SIZE:", BLOCK_SIZE)
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# print(BLOCK_SIZE, num_warps, "block_size, numwarps")
BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2
num_warps = 8
# enqueue kernel
_rms_norm_fwd_fused[(M,)](x_arg, y, weight,
x_arg.stride(0), N, eps,
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
return y
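
A minimal sanity check for the kernel above, assuming a CUDA device with Triton installed and fp16 inputs; this snippet is illustrative only and is not part of the PR's test suite.

import torch
from colossalai.kernel.triton import rmsnorm_forward

hidden = 4096
x = torch.randn(2, 16, hidden, dtype=torch.float16, device="cuda")
weight = torch.ones(hidden, dtype=torch.float16, device="cuda")
eps = 1e-6

# fp32 reference matching LlamaRMSNorm semantics
variance = x.float().pow(2).mean(-1, keepdim=True)
ref = (x.float() * torch.rsqrt(variance + eps) * weight.float()).to(torch.float16)

out = rmsnorm_forward(x, weight, eps)
print(torch.max(torch.abs(out - ref)))  # expect only an fp16-level difference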
10 changes: 8 additions & 2 deletions tests/test_infer/test_bloom_infer.py
@@ -1,4 +1,6 @@
import os
import pytest
from packaging import version
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM
@@ -14,10 +16,14 @@
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 32

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')

def run():

def run():
model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b"
if os.path.isdir(model_path) is False:
return

tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

Expand Down Expand Up @@ -48,7 +54,7 @@ def check_engine(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
8 changes: 6 additions & 2 deletions tests/test_infer/test_infer_engine.py
@@ -1,6 +1,7 @@
import pytest
from itertools import accumulate
from packaging import version

import pytest
import torch
import torch.nn as nn
from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
@@ -18,7 +19,9 @@
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 8

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')

@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
def test_prepare_data():
# dummy module used for testing
class DummyModule(nn.Module):
@@ -68,7 +71,7 @@ def __init__(self):
assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
def test_orig_generate():
input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))

@@ -116,6 +119,7 @@ def check_engine(rank, world_size, port):
run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
5 changes: 3 additions & 2 deletions tests/test_infer/test_kvcache_manager.py
@@ -1,5 +1,5 @@
import os

from packaging import version
import pytest
import torch

@@ -14,6 +14,7 @@
HEAD_NUM = 32
HEAD_DIM = 128

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')

def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
os.environ['RANK'] = str(rank)
@@ -42,7 +43,7 @@ def create_cache_manager(rank, world_size, port, batch_size, input_len, output_l
kvcache_manager.alloc_contiguous(batch_size)
assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager_dist():