diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 82f294163fd7..c68092ee9427 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -8,7 +8,6 @@
     LlamaDecoderLayer,
     LlamaModel,
     LlamaRMSNorm,
-    apply_rotary_pos_emb,
 )
 
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
@@ -29,6 +28,29 @@
     )
     HAS_VLLM_KERNERL = False
 
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
+    copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+    copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+    return
+
 
 class LlamaInferenceForwards:
     """
@@ -251,8 +273,9 @@ def llama_flash_attn_kvcache_forward(
 
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
         key_states_transposed = key_states.transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+
         # NOTE might want to revise
         #   need some way to record the length of past key values cache
@@ -261,26 +284,42 @@ def llama_flash_attn_kvcache_forward(
             infer_state.cache_manager.past_key_values_length += q_len  # seq_len
 
         if HAS_VLLM_KERNERL:
+            # NOTE: fix rotary embedding precision problem
             cos, sin = infer_state.position_cos, infer_state.position_sin
+
+            # value_states_transposed = value_states.transpose(1, 2)
+            # cos, sin = self.rotary_emb(value_states_transposed,
+            #                            seq_len=infer_state.cache_manager.past_key_values_length)
+
             cos_sin_cache = torch.cat((cos, sin), dim=-1)
-            rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache)
-            key_states = key_states_transposed.transpose(1, 2)
+
+            key_states = key_states.view(-1, self.num_heads * self.head_dim)
+            query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads * self.head_dim)
+            rotary_embedding_neox(position_ids.squeeze(1), query_states, key_states, self.head_dim, cos_sin_cache)
+
+            query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+            key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
+            value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
+
         else:
             # NOTE: there are some issues for original rotary_embedding_neox of huggingface
+            value_states_transposed = value_states.transpose(1, 2)
             cos, sin = self.rotary_emb(value_states_transposed,
-                                       seq_len=infer_state.cache_manager.past_key_values_length)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed,
-                                                            cos, sin, position_ids)
-            key_states = key_states_transposed.transpose(1, 2)
-
-        def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
-            copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
-            copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
-            return
-
-        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
-        value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
-        query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
+                                       seq_len=infer_state.cache_manager.past_key_values_length)
+
+            rotary_positions_ids = position_ids
+            idx = position_ids.shape[0] - 1
+            if idx >= 1:
+                rotary_positions_ids = [[idx]]
+
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, rotary_positions_ids)
+
+            query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
+            key_states = key_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim)
+            value_states = value_states.reshape(-1, self.num_heads, self.head_dim)
 
         if infer_state.is_context_stage:
             # first token generation
@@ -330,7 +369,6 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
 
 def get_llama_vllm_rmsnorm_forward():
     if HAS_VLLM_KERNERL:
-
         def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
             x = hidden_states
             out = torch.empty_like(x)
@@ -346,3 +384,4 @@ def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
         return _vllm_rmsnorm_forward
     else:
         return None
+
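Reviewer note: a quick shape check for the `apply_rotary_pos_emb` fallback path above. Sizes are hypothetical; the cos/sin cache is built the way `LlamaRotaryEmbedding` builds it, and the function is imported from the module this diff modifies.

```python
import torch
from colossalai.inference.tensor_parallel.modeling.llama import apply_rotary_pos_emb

bs, num_heads, seq_len, head_dim = 1, 32, 4, 128
q = torch.randn(bs, num_heads, seq_len, head_dim)
k = torch.randn(bs, num_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)  # [bs, seq_len]

# cos/sin caches shaped [1, 1, seq_len, head_dim], as in LlamaRotaryEmbedding
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```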
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py
index bbd2156b8523..e819f2a8810c 100644
--- a/colossalai/inference/tensor_parallel/policies/llama.py
+++ b/colossalai/inference/tensor_parallel/policies/llama.py
@@ -1,12 +1,33 @@
 from functools import partial
+import torch
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaModel,
+    LlamaRMSNorm
+)
 
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
-
+# import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
-
 from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
 
+try:
+    from colossalai.kernel.triton.rms_norm import rmsnorm_forward
+    HAS_TRITON_RMSNORM = True
+except:
+    print("you should install triton from https://github.com/openai/triton")
+    HAS_TRITON_RMSNORM = False
+
+def get_triton_rmsnorm_forward():
+    if HAS_TRITON_RMSNORM:
+        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+            return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
+
+        return _triton_rmsnorm_forward
+    else:
+        return None
+
 
 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
 
     def __init__(self) -> None:
@@ -16,12 +37,6 @@ def module_policy(self):
         policy = super().module_policy()
         self.shard_config._infer()
 
-        # example for replace layer or decoder
-        # if self.shard_config.enable_flash_attention:
-        #     policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
-        #         'forward': get_llama_flash_attention_forward(),
-        #     })
-
         infer_forward = LlamaInferenceForwards.llama_model_forward
         method_replacement = {'forward': partial(infer_forward)}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
@@ -38,12 +53,18 @@ def module_policy(self):
                                                      policy=policy,
                                                      target_key=LlamaAttention)
 
-        # NOTE: adding rms_norm caused precision issue, fix @tiandiao123
-        # infer_forward = get_llama_vllm_rmsnorm_forward()
-        # if infer_forward is not None:
-        #     method_replacement = {'forward': partial(infer_forward)}
-        #     self.append_or_create_method_replacement(description=method_replacement,
-        #                                              policy=policy,
-        #                                              target_key=LlamaRMSNorm)
+        infer_forward = None
+        if HAS_TRITON_RMSNORM:
+            infer_forward = get_triton_rmsnorm_forward()
+        else:
+            # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
+            infer_forward = get_llama_vllm_rmsnorm_forward()
+
+        if infer_forward is not None:
+            method_replacement = {'forward': partial(infer_forward)}
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=LlamaRMSNorm)
 
         return policy
+
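Reviewer note: the policy change above boils down to rebinding `LlamaRMSNorm.forward` to a drop-in kernel. A minimal sketch of the runtime effect (assuming a CUDA build with Triton installed; the real wiring goes through `append_or_create_method_replacement`):

```python
import types
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from colossalai.inference.tensor_parallel.policies.llama import get_triton_rmsnorm_forward

norm = LlamaRMSNorm(4096).half().cuda()
infer_forward = get_triton_rmsnorm_forward()  # None when Triton is unavailable
if infer_forward is not None:
    # bind the replacement as an instance method, mirroring the policy machinery
    norm.forward = types.MethodType(infer_forward, norm)

x = torch.randn(2, 16, 4096, dtype=torch.float16, device="cuda")
y = norm(x)  # now dispatches to the Triton rmsnorm_forward kernel
```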
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index 9655d720406a..75bd4ed80a72 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -1,3 +1,4 @@
 from .context_attention import llama_context_attn_fwd, bloom_context_attn_fwd
 from .softmax import softmax
 from .copy_kv_cache_dest import copy_kv_cache_to_dest
+from .rms_norm import rmsnorm_forward
diff --git a/colossalai/kernel/triton/rms_norm.py b/colossalai/kernel/triton/rms_norm.py
new file mode 100644
index 000000000000..1fb79115f8ce
--- /dev/null
+++ b/colossalai/kernel/triton/rms_norm.py
@@ -0,0 +1,72 @@
+import torch
+
+try:
+    import triton
+    import triton.language as tl
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+
+if HAS_TRITON:
+    '''
+    this kernel function is modified from
+    https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
+    '''
+    @triton.jit
+    def _rms_norm_fwd_fused(
+        X,  # pointer to the input
+        Y,  # pointer to the output
+        W,  # pointer to the weights
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
+        eps,  # epsilon to avoid division by zero
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        # Map the program id to the row of X and Y it should compute.
+        row = tl.program_id(0)
+        Y += row * stride
+        X += row * stride
+        # Compute variance
+        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+            _var += x * x
+        var = tl.sum(_var, axis=0) / N
+        rstd = 1 / tl.sqrt(var + eps)
+        # Normalize and apply linear transformation
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            mask = cols < N
+            w = tl.load(W + cols, mask=mask).to(tl.float32)
+            x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
+            x_hat = x * rstd
+            y = x_hat * w
+            # Write output
+            tl.store(Y + cols, y.to(tl.float16), mask=mask)
+
+
+    def rmsnorm_forward(x, weight, eps):
+        # allocate output
+        y = torch.empty_like(x)
+        # reshape input data into 2D tensor
+        x_arg = x.view(-1, x.shape[-1])
+        M, N = x_arg.shape
+        # Less than 64KB per feature: enqueue fused kernel
+        MAX_FUSED_SIZE = 65536 // x.element_size()
+        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+        # print("BLOCK_SIZE:", BLOCK_SIZE)
+        if N > BLOCK_SIZE:
+            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+        # heuristics for number of warps
+        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+        # print(BLOCK_SIZE, num_warps, "block_size, numwarps")
+        BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2
+        num_warps = 8
+        # enqueue kernel
+        _rms_norm_fwd_fused[(M,)](x_arg, y, weight,
+                                  x_arg.stride(0), N, eps,
+                                  BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+        return y
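Reviewer note: a pure-PyTorch reference is useful for sanity-checking the new kernel. A minimal sketch: the kernel accumulates the mean of squares in fp32 and stores fp16, so the reference mirrors that and the comparison uses a loose tolerance. The sizes here are hypothetical and must keep the feature dim within the hardcoded BLOCK_SIZE.

```python
import torch
from colossalai.kernel.triton import rmsnorm_forward

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # mirror the kernel: mean of squares in fp32, scale by weight, cast back to fp16
    x32 = x.float()
    rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * rstd * weight.float()).to(torch.float16)

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
w = torch.randn(4096, dtype=torch.float16, device="cuda")
assert torch.allclose(rmsnorm_forward(x, w, 1e-6), rmsnorm_reference(x, w, 1e-6),
                      rtol=1e-3, atol=1e-3)
```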
diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py
index 95ab7d5c451e..754c158e6279 100644
--- a/tests/test_infer/test_bloom_infer.py
+++ b/tests/test_infer/test_bloom_infer.py
@@ -1,4 +1,6 @@
+import os
 import pytest
+from packaging import version
 import torch
 import torch.distributed as dist
 from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM
@@ -14,10 +16,14 @@
 MAX_INPUT_LEN = 16
 MAX_OUTPUT_LEN = 32
 
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
 
-def run():
+def run():
     model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b"
+    if os.path.isdir(model_path) is False:
+        return
+
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     tokenizer.pad_token = tokenizer.eos_token
 
@@ -48,7 +54,7 @@ def check_engine(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run()
 
-
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py
index 6fcf9fe0f387..c9432509d941 100644
--- a/tests/test_infer/test_infer_engine.py
+++ b/tests/test_infer/test_infer_engine.py
@@ -1,6 +1,7 @@
+import pytest
 from itertools import accumulate
+from packaging import version
 
-import pytest
 import torch
 import torch.nn as nn
 from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
@@ -18,7 +19,9 @@
 MAX_INPUT_LEN = 16
 MAX_OUTPUT_LEN = 8
 
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
 
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 def test_prepare_data():
     # dummy module used for testing
     class DummyModule(nn.Module):
@@ -68,7 +71,7 @@ def __init__(self):
     assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
     assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)
 
-
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 def test_orig_generate():
     input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
@@ -116,6 +119,7 @@ def check_engine(rank, world_size, port):
     run()
 
 
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py
index fb04d7800ea2..f57c6956f817 100644
--- a/tests/test_infer/test_kvcache_manager.py
+++ b/tests/test_infer/test_kvcache_manager.py
@@ -1,5 +1,5 @@
 import os
-
+from packaging import version
 import pytest
 import torch
@@ -14,6 +14,7 @@
 HEAD_NUM = 32
 HEAD_DIM = 128
 
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
 
 def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
     os.environ['RANK'] = str(rank)
@@ -42,7 +43,7 @@ def create_cache_manager(rank, world_size, port, batch_size, input_len, output_l
     kvcache_manager.alloc_contiguous(batch_size)
     assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)
 
-
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_cache_manager_dist():
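Reviewer note: one caveat with the `CUDA_SUPPORT` guard added to these tests. `torch.version.cuda` is `None` on CPU-only builds, and `version.parse(None)` raises a `TypeError`, so the guard itself can crash collection there. A more defensive variant (hypothetical, not part of this PR) would be:

```python
import torch
from packaging import version

cuda_ver = torch.version.cuda  # None on CPU-only builds, so check before parsing
CUDA_SUPPORT = cuda_ver is not None and version.parse(cuda_ver) > version.parse('11.5')
```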
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index c8f852aef420..3caa73ac23d4 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -1,24 +1,26 @@
 import os
-
-import numpy as np
 import pytest
 import torch
+from packaging import version
+import numpy as np
 import torch.distributed as dist
 from transformers import LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.inference.tensor_parallel.engine import TPInferEngine
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-TPSIZE = 2
+TPSIZE = 1
 BATCH_SIZE = 8
 MAX_INPUT_LEN = 12
 MAX_OUTPUT_LEN = 100
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
 
 def init_to_get_rotary(self, base=10000):
     self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
@@ -33,8 +35,7 @@ def init_to_get_rotary(self, base=10000):
     else:
         max_seq_len = 2048 * rope_scaling_factor
     base = float(base)
-    inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
-                             self.config.head_dim_))
+    inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
     t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
     freqs = torch.outer(t, inv_freq)
 
@@ -42,33 +43,35 @@
     self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
     return
 
-
 @parameterize('test_config', [{
     'tp_size': TPSIZE,
 }])
 def run_llama_test(test_config):
-
+
     llama_model_path = "/data/scratch/llama-7b-hf"
+    if os.path.isdir(llama_model_path) is False:
+        return
+
     tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
     init_to_get_rotary(model.model, base=10000)
     model = model.half()
-
-    text = "how is weather today?"
+
+    text = "where is the location of the capital of france?"
     input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda')
-
+
     infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
     shardformer = ShardFormer(shard_config=shard_config)
-
+
     infer_engine.prepare_with_shard_config(shard_config)
     infer_engine.shard_model_by(shardformer)
 
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
     outputs = infer_engine.generate(input_ids, generate_kwargs)
 
     print("outputs.shape: ", outputs.shape)
-
+
     print("outputs: ", outputs)
     output_text = tokenizer.decode(outputs[0])
@@ -80,13 +83,12 @@ def check_llama(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_llama_test()
 
-
+@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_llama():
     spawn(check_llama, TPSIZE)
 
-
 if __name__ == "__main__":
     test_llama()
diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py
index e7446b289acd..b0fac1263047 100644
--- a/tests/test_infer_ops/triton/test_llama_context_attention.py
+++ b/tests/test_infer_ops/triton/test_llama_context_attention.py
@@ -44,7 +44,7 @@ def test_llama_context_attention():
 
     torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
 
-    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched"
+    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3), "outputs from triton and torch are not matched"
 
     latency_1 = benchmark(llama_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len)
     latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)
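Reviewer note: `torch_context_attention` is defined in the test file and not shown in this diff, so the tightened `atol` is being applied against a reference we can't see here. A reference of that kind is typically plain causal softmax attention computed in fp32, roughly (a sketch, not the test's actual implementation):

```python
import math
import torch

def causal_attention_reference(q, k, v):
    # q, k, v: [bs, head_num, seq_len, head_dim]; fp32 reference for comparison
    scores = q.float() @ k.float().transpose(-2, -1) / math.sqrt(q.shape[-1])
    causal = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    return (torch.softmax(scores, dim=-1) @ v.float()).to(q.dtype)
```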