From c7d6d8d9a49633ea87f0146d18a330ff74d30d68 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Mon, 4 Sep 2023 17:06:25 +0800 Subject: [PATCH 1/7] fix rotary embedding --- .../tensor_parallel/modeling/llama.py | 49 +++++----- .../kernel/triton/rotary_embedding_kernel.py | 93 +++++++++++++++++++ tests/test_infer/test_llama_infer.py | 12 ++- .../triton/test_rotary_embedding.py | 51 ++++++++++ 4 files changed, 174 insertions(+), 31 deletions(-) create mode 100644 colossalai/kernel/triton/rotary_embedding_kernel.py create mode 100644 tests/test_infer_ops/triton/test_rotary_embedding.py diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 82f294163fd7..ba72226f8941 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -14,6 +14,7 @@ from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton.context_attention import llama_context_attn_fwd from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest +from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd try: @@ -53,17 +54,6 @@ def llama_model_forward( batch_size = input_ids.shape[0] # input_ids.shape[0] infer_state = self.infer_state - b_seq_len_numpy = infer_state.seq_len.cpu().numpy() - - if HAS_VLLM_KERNERL: - position_ids = torch.from_numpy( - np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda() - - # this equals - infer_state.position_cos = torch.index_select(self._cos_cached, 0, - position_ids).view(position_ids.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, - position_ids).view(position_ids.shape[0], -1) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -117,6 +107,21 @@ def llama_model_forward( # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + if infer_state.is_context_stage: + b_seq_len_numpy = infer_state.seq_len.cpu().numpy() + position_ids = torch.from_numpy( + np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda() + + infer_state.position_cos = torch.index_select(self._cos_cached, 0, + position_ids).view(position_ids.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, + position_ids).view(position_ids.shape[0], -1) + position_ids = None + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange(past_key_values_length, @@ -249,10 +254,9 @@ def llama_flash_attn_kvcache_forward( # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states_transposed = key_states.transpose(1, 2) # NOTE might want to revise # need some way to record the length of past key values cache @@ -260,27 +264,20 @@ def llama_flash_attn_kvcache_forward( if infer_state.decode_layer_id == 0: # once per model.forward infer_state.cache_manager.past_key_values_length += q_len # seq_len - if HAS_VLLM_KERNERL: - cos, sin = infer_state.position_cos, infer_state.position_sin - cos_sin_cache = torch.cat((cos, sin), dim=-1) - rotary_embedding_neox(position_ids, query_states, key_states_transposed, self.head_dim, cos_sin_cache) - key_states = key_states_transposed.transpose(1, 2) - else: - # NOTE: there are some issues for original rotary_embedding_neox of huggingface - value_states_transposed = value_states.transpose(1, 2) - cos, sin = self.rotary_emb(value_states_transposed, - seq_len=infer_state.cache_manager.past_key_values_length) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states_transposed, cos, sin, position_ids) - key_states = key_states_transposed.transpose(1, 2) + cos, sin = infer_state.position_cos, infer_state.position_sin + # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) + + rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) return + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) key_states = key_states.reshape(-1, self.num_heads, self.head_dim) value_states = value_states.reshape(-1, self.num_heads, self.head_dim) - query_states = query_states.transpose(1, 2).reshape(-1, self.num_heads, self.head_dim) if infer_state.is_context_stage: # first token generation diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py new file mode 100644 index 000000000000..d9d1b2bcf026 --- /dev/null +++ b/colossalai/kernel/triton/rotary_embedding_kernel.py @@ -0,0 +1,93 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rotary_kernel( + q, + Cos, + Sin, + q_bs_stride, + q_h_stride, + q_d_stride, + cos_bs_stride, + cos_d_stride, + total_len, + HEAD_NUM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + current_head_index = tl.program_id(0) + current_seq_index = tl.program_id(1) + + current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + off_q0 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ + None, :, None] * q_h_stride + dim_range0[None, None, :] * q_d_stride + off_q1 = current_seq_range[:, None, None] * q_bs_stride + current_head_range[ + None, :, None] * q_h_stride + dim_range1[None, None, :] * q_d_stride + + off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride + + q0 = tl.load(q + off_q0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0) + q1 = tl.load(q + off_q1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0) + + cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + + out0 = q0 * cos - q1 * sin + out1 = q0 * sin + q1 * cos + + tl.store(q + off_q0, + out0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) + tl.store(q + off_q1, + out1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM)) + + return + + +@torch.no_grad() +def rotary_embedding_fwd(q, cos, sin): + total_len = q.shape[0] + head_num = q.shape[1] + head_dim = q.shape[2] + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + BLOCK_HEAD = 4 + BLOCK_SEQ = 32 + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + _rotary_kernel[grid]( + q, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + total_len, + HEAD_NUM=head_num, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + HEAD_DIM=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index c8f852aef420..101cb019b462 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -55,9 +55,10 @@ def run_llama_test(test_config): init_to_get_rotary(model.model, base=10000) model = model.half() - text = "how is weather today?" - input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda') + text = ["how is weather today?", "i am "] + input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda') + print("input ids ", input_ids) infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) @@ -70,9 +71,10 @@ def run_llama_test(test_config): print("outputs.shape: ", outputs.shape) print("outputs: ", outputs) - - output_text = tokenizer.decode(outputs[0]) - print(output_text) + if not dist.is_initialized() or dist.get_rank() == 0: + for o in outputs: + output_text = tokenizer.decode(o) + print(output_text) def check_llama(rank, world_size, port): diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py new file mode 100644 index 000000000000..b8637b7965ef --- /dev/null +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -0,0 +1,51 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm + +import time + +import pytest +import torch +from packaging import version + +try: + import triton + import triton.language as tl + + from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + + +def torch_rotary_emb(x, cos, sin): + seq_len, h, dim = x.shape + x0 = x[:, :, 0:dim // 2] + x1 = x[:, :, dim // 2:dim] + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + + +def test_rotary_emb(): + SEQ_LEN = 512 + HEAD_NUM = 16 + HEAD_DIM = 64 + dtype = torch.half + # create data + x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + cos_shape = (SEQ_LEN, HEAD_DIM // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda') + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device='cuda') + # forward pass + y_torch = torch_rotary_emb(x, cos, sin) + rotary_embedding_fwd(x, cos, sin) + y_triton = x + + # compare + print("max delta:", torch.max(torch.abs(y_torch - y_triton))) + assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2) From 64c0782bf7e918164d6873407bca4c6c428d8199 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 13:41:07 +0800 Subject: [PATCH 2/7] delete print --- tests/test_infer/test_llama_infer.py | 8 ++++---- .../triton/test_rotary_embedding.py | 16 +++++++++++----- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index b098375768fa..1d043ba59338 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -64,7 +64,7 @@ def run_llama_test(test_config): text = ["how is weather today?", "i am "] input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda') - print("input ids ", input_ids) + #print("input ids ", input_ids) infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) @@ -74,13 +74,13 @@ def run_llama_test(test_config): generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) outputs = infer_engine.generate(input_ids, generate_kwargs) - print("outputs.shape: ", outputs.shape) + #print("outputs.shape: ", outputs.shape) - print("outputs: ", outputs) + #print("outputs: ", outputs) if not dist.is_initialized() or dist.get_rank() == 0: for o in outputs: output_text = tokenizer.decode(o) - print(output_text) + #print(output_text) def check_llama(rank, world_size, port): diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py index b8637b7965ef..3a82a6603318 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -11,6 +11,8 @@ import triton.language as tl from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd + from tests.test_infer_ops.triton.utils import benchmark + HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -31,9 +33,9 @@ def torch_rotary_emb(x, cos, sin): def test_rotary_emb(): - SEQ_LEN = 512 - HEAD_NUM = 16 - HEAD_DIM = 64 + SEQ_LEN = 1 + HEAD_NUM = 32 + HEAD_DIM = 128 dtype = torch.half # create data x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) @@ -45,7 +47,11 @@ def test_rotary_emb(): y_torch = torch_rotary_emb(x, cos, sin) rotary_embedding_fwd(x, cos, sin) y_triton = x - + # print("max delta:", torch.max(torch.abs(y_torch - y_triton))) # compare - print("max delta:", torch.max(torch.abs(y_torch - y_triton))) assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2) + + # triton_latency = benchmark(rotary_embedding_fwd, x, cos, sin) + # torch_latency = benchmark(torch_rotary_emb, x, cos, sin) + # print("triton kernel latency:{:.6f} ms".format(triton_latency)) + # print("torch kernel latency:{:.6f} ms".format(torch_latency)) From 64d8d557e4c9a99902816421747ec86b534f4d03 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 13:49:11 +0800 Subject: [PATCH 3/7] fix init seq len bug --- colossalai/inference/tensor_parallel/engine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 986177555b06..c6abb74f080b 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -181,10 +181,11 @@ def prepare_batch_state(self, inputs) -> BatchInferState: max_len_in_batch = -1 if isinstance(inputs, (BatchEncoding, dict)): for i, attn_mask in enumerate(attention_mask): - if isinstance(attn_mask, torch.Tensor): - curr_seq_len = int(torch.sum(attn_mask)) - else: - curr_seq_len = int(sum(attn_mask)) + curr_seq_len = len(attn_mask) + # if isinstance(attn_mask, torch.Tensor): + # curr_seq_len = int(torch.sum(attn_mask)) + # else: + # curr_seq_len = int(sum(attn_mask)) seq_lengths[i] = curr_seq_len seq_start_indexes[i] = start_index start_index += curr_seq_len @@ -196,7 +197,6 @@ def prepare_batch_state(self, inputs) -> BatchInferState: seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda') batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device From a255925b7fa32bc3099ddec2cc54389bfe8e1990 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 14:10:06 +0800 Subject: [PATCH 4/7] rename pytest --- .../triton/{test_layernorm.py => test_layernorm_triton.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/test_infer_ops/triton/{test_layernorm.py => test_layernorm_triton.py} (100%) diff --git a/tests/test_infer_ops/triton/test_layernorm.py b/tests/test_infer_ops/triton/test_layernorm_triton.py similarity index 100% rename from tests/test_infer_ops/triton/test_layernorm.py rename to tests/test_infer_ops/triton/test_layernorm_triton.py From c525e78fbf58a29b2d16015155db16b41c095ad8 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 14:10:34 +0800 Subject: [PATCH 5/7] add benchmark for llama --- examples/inference/bench_llama.py | 153 ++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 examples/inference/bench_llama.py diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py new file mode 100644 index 000000000000..7bb1d11259d7 --- /dev/null +++ b/examples/inference/bench_llama.py @@ -0,0 +1,153 @@ +import os +import time + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from torch.profiler import ProfilerActivity, profile, record_function +from transformers import LlamaForCausalLM, LlamaTokenizer + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +TPSIZE = 1 +BATCH_SIZE = 32 +MAX_INPUT_LEN = 1024 +MAX_OUTPUT_LEN = 256 + + +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / + self.config.head_dim_)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + +def print_perf_stats(latency_set, config, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = getattr(config, "num_layers", config.num_hidden_layers) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * BATCH_SIZE / 1e12)) + + +@parameterize('test_config', [{ + 'tp_size': TPSIZE, +}]) +def run_llama_test(test_config): + + llama_model_path = "/data/scratch/llama-7b-hf" + tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) + tokenizer.pad_token_id = tokenizer.unk_token_id + model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) + init_to_get_rotary(model.model, base=10000) + model = model.half() + input = "" + for i in range(1022): + input += "a " + model_config = model.config + text = [input] + + batch_size = 32 + input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda') + + print("input ids ", input_ids) + infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + shardformer = ShardFormer(shard_config=shard_config) + + infer_engine.prepare_with_shard_config(shard_config) + infer_engine.shard_model_by(shardformer) + print("input.shape: ", input_ids["input_ids"].shape) + + max_new_tokens = 128 + generate_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False) + outputs = infer_engine.generate(input_ids, generate_kwargs) + infer_engine.cache_manager.free_all() + print("outputs.shape: ", outputs.shape) + + input_len = 1024 + input_tokens = { + "input_ids": torch.randint(1, 1000, (batch_size, input_len), device='cuda'), + "attention_mask": torch.ones((batch_size, input_len), device='cuda') + } + input_tokens["input_ids"][:, 0] = 0 + # print("inputs ", input_tokens) + # print(input_tokens["input_ids"].shape) + input_len = input_tokens["input_ids"].shape[1] + + iters = 10 + + times = [] + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + times.append((end - start) / (out_len - input_len)) + infer_engine.cache_manager.free_all() + + print("outputs, ", len(outputs)) + outputs = tokenizer.batch_decode(outputs) + + print_perf_stats(times, model_config) + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("model_inference"): + torch.cuda.synchronize() + outputs = infer_engine.generate(input_tokens, generate_kwargs) + torch.cuda.synchronize() + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_llama_test() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_llama(): + spawn(check_llama, TPSIZE) + + +if __name__ == "__main__": + test_llama() From 117cdf17b32d1e47a62ce1fe81a501ea681411ae Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 16:40:13 +0800 Subject: [PATCH 6/7] refactor codes --- .../tensor_parallel/modeling/llama.py | 27 ++++++++----------- examples/inference/bench_llama.py | 23 +++------------- 2 files changed, 15 insertions(+), 35 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index ab18d4a97ab7..219cd1ae0d0e 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -125,22 +125,6 @@ def llama_model_forward( # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - - if infer_state.is_context_stage: - b_seq_len_numpy = infer_state.seq_len.cpu().numpy() - position_ids = torch.from_numpy( - np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0)).cuda() - - infer_state.position_cos = torch.index_select(self._cos_cached, 0, - position_ids).view(position_ids.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, - position_ids).view(position_ids.shape[0], -1) - position_ids = None - else: - seq_len = infer_state.seq_len - infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange(past_key_values_length, @@ -151,6 +135,17 @@ def llama_model_forward( else: position_ids = position_ids.view(-1, seq_length).long() + if infer_state.is_context_stage: + + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 7bb1d11259d7..1aabd340aedd 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -74,42 +74,27 @@ def run_llama_test(test_config): model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) init_to_get_rotary(model.model, base=10000) model = model.half() - input = "" - for i in range(1022): - input += "a " - model_config = model.config - text = [input] - batch_size = 32 - input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda') + model_config = model.config - print("input ids ", input_ids) infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) infer_engine.prepare_with_shard_config(shard_config) infer_engine.shard_model_by(shardformer) - print("input.shape: ", input_ids["input_ids"].shape) + batch_size = 2 max_new_tokens = 128 - generate_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False) - outputs = infer_engine.generate(input_ids, generate_kwargs) - infer_engine.cache_manager.free_all() - print("outputs.shape: ", outputs.shape) - input_len = 1024 + + generate_kwargs = dict(max_new_tokens=max_new_tokens, do_sample=False) input_tokens = { "input_ids": torch.randint(1, 1000, (batch_size, input_len), device='cuda'), "attention_mask": torch.ones((batch_size, input_len), device='cuda') } - input_tokens["input_ids"][:, 0] = 0 - # print("inputs ", input_tokens) - # print(input_tokens["input_ids"].shape) - input_len = input_tokens["input_ids"].shape[1] iters = 10 - times = [] for i in range(iters): From 2af8002dfa93696ed9eda8f793327b9e17513369 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Tue, 5 Sep 2023 16:57:51 +0800 Subject: [PATCH 7/7] delete useless code --- tests/test_infer_ops/triton/test_rotary_embedding.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py index 3a82a6603318..4413dba642b8 100644 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ b/tests/test_infer_ops/triton/test_rotary_embedding.py @@ -50,8 +50,3 @@ def test_rotary_emb(): # print("max delta:", torch.max(torch.abs(y_torch - y_triton))) # compare assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2) - - # triton_latency = benchmark(rotary_embedding_fwd, x, cos, sin) - # torch_latency = benchmark(torch_rotary_emb, x, cos, sin) - # print("triton kernel latency:{:.6f} ms".format(triton_latency)) - # print("torch kernel latency:{:.6f} ms".format(torch_latency))