71 changes: 55 additions & 16 deletions colossalai/inference/tensor_parallel/modeling/llama.py
@@ -8,7 +8,6 @@
    LlamaDecoderLayer,
    LlamaModel,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
)

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
@@ -29,6 +28,29 @@
    )
    HAS_VLLM_KERNERL = False

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

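For reference, a minimal sketch of driving the local `apply_rotary_pos_emb` defined above. The cos/sin construction mirrors HF's `LlamaRotaryEmbedding` layout; the concrete sizes are illustrative assumptions, not values from this PR:

```python
import torch

# Toy shapes for illustration only: bs=2, num_heads=4, seq_len=8, head_dim=16.
bs, num_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bs, num_heads, seq_len, head_dim)
k = torch.randn(bs, num_heads, seq_len, head_dim)

# Build cos/sin the way HF's LlamaRotaryEmbedding lays them out:
# [1, 1, seq_len, head_dim], with each frequency duplicated across both halves.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len).float()
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()[None, None, :, :]
sin = emb.sin()[None, None, :, :]

position_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_embed.shape == q.shape and k_embed.shape == k.shape
```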
def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
    copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
    copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
    return

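`copy_kv_cache_to_dest` is ColossalAI's fused copy kernel, so the helper above scatters freshly computed K/V rows into the layer's preallocated cache buffers. A pure-PyTorch sketch of the assumed semantics (illustration, not the kernel):

```python
import torch

def copy_kv_cache_to_dest_reference(src, dest_index, dest):
    # Assumed semantics: scatter each token's K/V row into the cache slots
    # chosen by the memory manager.
    #   src:        [num_tokens, num_heads, head_dim]
    #   dest_index: [num_tokens] long tensor of slot ids
    #   dest:       [max_total_tokens, num_heads, head_dim]
    dest[dest_index] = src
```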

class LlamaInferenceForwards:
"""
@@ -251,8 +273,9 @@ def llama_flash_attn_kvcache_forward(

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
        key_states_transposed = key_states.transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)

        # NOTE: might want to revise
        # need some way to record the length of past key values cache
@@ -261,26 +284,42 @@
        infer_state.cache_manager.past_key_values_length += q_len  # seq_len

        if HAS_VLLM_KERNERL:
            # NOTE: fix rotary embedding precision problem
            cos, sin = infer_state.position_cos, infer_state.position_sin

            # value_states_transposed = value_states.transpose(1, 2)

            # cos, sin = self.rotary_emb(value_states_transposed,
            #                            seq_len=infer_state.cache_manager.past_key_values_length)

            cos_sin_cache = torch.cat((cos, sin), dim=-1)
            rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
            key_states = key_states_transposed.transpose(1, 2)

            key_states = key_states.view(-1, self.num_heads * self.head_dim)
            query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads * self.head_dim)
            rotary_embedding_neox(position_ids.squeeze(1), query_states, key_states, self.head_dim, cos_sin_cache)

            query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
            key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
            value_states = value_states.reshape(-1, self.num_heads, self.head_dim)

        else:
            # NOTE: there are some issues with the original rotary_embedding_neox from huggingface

            value_states_transposed = value_states.transpose(1, 2)
            cos, sin = self.rotary_emb(value_states_transposed,
                                       seq_len=infer_state.cache_manager.past_key_values_length)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids)
            key_states = key_states_transposed.transpose(1, 2)

            def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
                copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
                copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
                return

            key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
            value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
            query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)

            rotary_positions_ids = position_ids
            idx = position_ids.shape[0] - 1
            if idx >= 1:
                rotary_positions_ids = [[idx]]
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, rotary_positions_ids)
            query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
            key_states = key_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
            value_states = value_states.reshape(-1, self.num_heads, self.head_dim)

        if infer_state.is_context_stage:
            # first token generation
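A note on the vLLM-kernel path above: `rotary_embedding_neox` takes one packed table instead of separate cos/sin tensors, which is what `cos_sin_cache = torch.cat((cos, sin), dim=-1)` builds. A minimal sketch of that layout, assuming the caches are produced the way `init_to_get_rotary` in the tests below produces them (one row per absolute position, cos half then sin half):

```python
import torch

max_seq_len, rot_dim = 4096, 128    # illustrative sizes, not from this PR
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
t = torch.arange(max_seq_len).float()
freqs = torch.outer(t, inv_freq)               # [max_seq_len, rot_dim // 2]
cos, sin = freqs.cos(), freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1)  # [max_seq_len, rot_dim]
```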
6 changes: 0 additions & 6 deletions colossalai/inference/tensor_parallel/policies/llama.py
@@ -16,12 +16,6 @@ def module_policy(self):
        policy = super().module_policy()
        self.shard_config._infer()

        # example for replace layer or decoder
        # if self.shard_config.enable_flash_attention:
        #     policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
        #         'forward': get_llama_flash_attention_forward(),
        #     })

        infer_forward = LlamaInferenceForwards.llama_model_forward
        method_replacement = {'forward': partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
10 changes: 8 additions & 2 deletions tests/test_infer/test_bloom_infer.py
@@ -1,4 +1,6 @@
import os
import pytest
from packaging import version
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM
@@ -14,10 +16,14 @@
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 32

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')

def run():
    model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b"
    if os.path.isdir(model_path) is False:
        return

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token

@@ -48,7 +54,7 @@ def check_engine(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
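One caveat with the `CUDA_SUPPORT` guard these tests now share: `torch.version.cuda` is `None` on CPU-only builds, so `version.parse(...)` would raise before the skip marker ever applies. A defensive variant (a suggestion, not part of this PR):

```python
import torch
from packaging import version

# Treat CPU-only builds (torch.version.cuda is None) as unsupported instead of
# letting version.parse(None) raise at import time.
_cuda = torch.version.cuda
CUDA_SUPPORT = _cuda is not None and version.parse(_cuda) > version.parse('11.5')
```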
8 changes: 6 additions & 2 deletions tests/test_infer/test_infer_engine.py
@@ -1,6 +1,7 @@
from itertools import accumulate
from packaging import version

import pytest
import torch
import torch.nn as nn
from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
@@ -18,7 +19,9 @@
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 8

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')

@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
def test_prepare_data():
    # dummy module used for testing
    class DummyModule(nn.Module):
@@ -68,7 +71,7 @@ def __init__(self):
    assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
    assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
def test_orig_generate():
    input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))

@@ -116,6 +119,7 @@ def check_engine(rank, world_size, port):
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
5 changes: 3 additions & 2 deletions tests/test_infer/test_kvcache_manager.py
@@ -1,5 +1,5 @@
import os

from packaging import version
import pytest
import torch

@@ -14,6 +14,7 @@
HEAD_NUM = 32
HEAD_DIM = 128

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')

def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
    os.environ['RANK'] = str(rank)
@@ -42,7 +43,7 @@ def create_cache_manager(rank, world_size, port, batch_size, input_len, output_l
    kvcache_manager.alloc_contiguous(batch_size)
    assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager_dist():
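The assertion above treats `mem_state` as a boolean free-mask that `alloc_contiguous` flips to `False` for the slots it hands out. A toy sketch of that bookkeeping, assuming exactly that convention (illustration only, not the real cache manager):

```python
import torch

class ToyKVCacheManager:
    def __init__(self, size: int):
        # True means "slot is free".
        self.mem_state = torch.ones(size, dtype=torch.bool)

    def alloc_contiguous(self, need: int):
        # Find the first run of `need` free slots and mark them occupied.
        run = 0
        for i in range(self.mem_state.numel()):
            run = run + 1 if self.mem_state[i] else 0
            if run == need:
                start = i - need + 1
                self.mem_state[start:start + need] = False
                return torch.arange(start, start + need)
        return None    # no contiguous block available
```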
36 changes: 19 additions & 17 deletions tests/test_infer/test_llama_infer.py
@@ -1,24 +1,26 @@
import os

import numpy as np
import pytest
import torch
from packaging import version
import torch.distributed as dist
from transformers import LlamaForCausalLM, LlamaTokenizer

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
TPSIZE = 2
TPSIZE = 1
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100

CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')

def init_to_get_rotary(self, base=10000):
    self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
@@ -33,42 +35,43 @@ def init_to_get_rotary(self, base=10000):
    else:
        max_seq_len = 2048 * rope_scaling_factor
    base = float(base)
    inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

    self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
    self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
    return


@parameterize('test_config', [{
    'tp_size': TPSIZE,
}])
def run_llama_test(test_config):

    llama_model_path = "/data/scratch/llama-7b-hf"
    if os.path.isdir(llama_model_path) is False:
        return

    tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
    tokenizer.pad_token_id = tokenizer.unk_token_id
    model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
    init_to_get_rotary(model.model, base=10000)
    model = model.half()

    text = "how is weather today?"
    text = "where is the location of the capital of france?"
    input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda')

    infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
    shardformer = ShardFormer(shard_config=shard_config)

    infer_engine.prepare_with_shard_config(shard_config)
    infer_engine.shard_model_by(shardformer)

    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
    outputs = infer_engine.generate(input_ids, generate_kwargs)
    print("outputs.shape: ", outputs.shape)

    print("outputs: ", outputs)

    output_text = tokenizer.decode(outputs[0])
@@ -80,13 +83,12 @@ def check_llama(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_llama_test()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    spawn(check_llama, TPSIZE)


if __name__ == "__main__":
    test_llama()
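`init_to_get_rotary` monkey-patches half-precision rotary tables onto the base `LlamaModel`. How `TPInferEngine` indexes them per position is an assumption here, but the attributes the helper creates look like this:

```python
# After init_to_get_rotary(model.model, base=10000):
#   model.model._cos_cached  # [max_seq_len + 65536, head_dim // 2], float16, on CUDA
#   model.model._sin_cached  # same shape and dtype
# Selecting rows by absolute position would yield per-token cos/sin, e.g.:
import torch

position_ids = torch.tensor([0, 1, 2], device='cuda')
cos = model.model._cos_cached[position_ids]
sin = model.model._sin_cached[position_ids]
```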
@@ -44,7 +44,7 @@ def test_llama_context_attention():

    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched"
    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3), "outputs from triton and torch are not matched"

    latency_1 = benchmark(llama_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len)
    latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)
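The `benchmark` helper used for `latency_1`/`latency_2` is not shown in this diff; a typical CUDA-timing sketch (assumed implementation, with warmup and a device sync so kernel time is actually measured) would be:

```python
import time
import torch

def benchmark(fn, *args, warmup: int = 3, iters: int = 10) -> float:
    # Average wall time per call; synchronize so async CUDA kernels are included.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```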