diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py
index 0fd08d3721e6..a6ee58f1e00d 100644
--- a/colossalai/inference/tensor_parallel/modeling/bloom.py
+++ b/colossalai/inference/tensor_parallel/modeling/bloom.py
@@ -197,7 +197,7 @@ def bloom_model_forward(
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if self.gradient_checkpointing and self.training:
-            # FIXME: currently our KV cache manager does not handle this condition
+            # NOTE: currently our KV cache manager does not handle this condition
 
             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -240,7 +240,8 @@ def custom_forward(*inputs):
         all_hidden_states = all_hidden_states + (hidden_states,)
 
     # update indices of kv cache block
-    # TODO: might want to remove this part, instead, better to pass the BatchInferState from model forward,
+    # NOT READY FOR PRIME TIME
+    # might want to remove this part, instead, better to pass the BatchInferState from model forward,
     #       and update these information in engine.generate after model foward called
     infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
     infer_state.seq_len += 1
diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index 1d9e366f99f3..94a13b968d0d 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -77,12 +77,13 @@ def llama_model_forward(
     past_key_values_length = 0
 
     if past_key_values is not None:
-        # TODO dummy but work, revise it
+        # NOT READY FOR PRIME TIME
+        # dummy but work, revise it
         past_key_values_length = infer_state.cache_manager.past_key_values_length
         # past_key_values_length = past_key_values[0][0].shape[2]
         seq_length_with_past = seq_length_with_past + past_key_values_length
 
-    # FIXME: differentiate with prefill stage
+    # NOTE: differentiate with prefill stage
     #       block_loc require different value-assigning method for two different stage
     if use_cache and seq_length != 1:
         # NOTE assuem prefill stage
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py
index c569a0e3163a..bbd2156b8523 100644
--- a/colossalai/inference/tensor_parallel/policies/llama.py
+++ b/colossalai/inference/tensor_parallel/policies/llama.py
@@ -1,10 +1,10 @@
 from functools import partial
+
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
 
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
-from ..modeling.llama import LlamaInferenceForwards
-from ..modeling.llama import get_llama_vllm_rmsnorm_forward
+from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
 
 
 class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
@@ -37,8 +37,8 @@ def module_policy(self):
             self.append_or_create_method_replacement(description=method_replacement,
                                                      policy=policy,
                                                      target_key=LlamaAttention)
-
-        # TODO: adding rms_norm caused precision issue, fix @tiandiao123
+
+        # NOTE: adding rms_norm caused precision issue, fix @tiandiao123
         # infer_forward = get_llama_vllm_rmsnorm_forward()
         # if infer_forward is not None:
         #     method_replacement = {'forward': partial(infer_forward)}
diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py
index a6c9bdfbdff6..6ae54dcb0b38 100644
--- a/colossalai/kernel/triton/self_attention_nofusion.py
+++ b/colossalai/kernel/triton/self_attention_nofusion.py
@@ -13,8 +13,9 @@
 from .qkv_matmul_kernel import qkv_gemm_4d_kernel
 from .softmax import softmax_kernel
 
-def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float):
-    r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
+def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                                          input_mask: torch.Tensor, scale: float):
+    r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels
     Args:
         q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
         k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size)
@@ -36,39 +37,49 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
     # head_size * num_of_head
     d_model = q.shape[-1] * q.shape[-2]
 
-    score_output = torch.empty(
-        (batches, H, M, N), device=q.device, dtype=q.dtype)
+    score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype)
 
     grid = lambda meta: (
         batches,
         H,
-        triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
-        triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+        triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
     )
 
     qkv_gemm_4d_kernel[grid](
-        q, k, score_output,
-        M, N, K,
-        q.stride(0), q.stride(2), q.stride(1), q.stride(3),
-        k.stride(0), k.stride(2), k.stride(3), k.stride(1),
-        score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3),
+        q,
+        k,
+        score_output,
+        M,
+        N,
+        K,
+        q.stride(0),
+        q.stride(2),
+        q.stride(1),
+        q.stride(3),
+        k.stride(0),
+        k.stride(2),
+        k.stride(3),
+        k.stride(1),
+        score_output.stride(0),
+        score_output.stride(1),
+        score_output.stride(2),
+        score_output.stride(3),
         scale=scale,
-        # currently manually setting, later on we can use auto-tune config to match best setting
+        # currently manually setting, later on we can use auto-tune config to match best setting
         BLOCK_SIZE_M=64,
         BLOCK_SIZE_N=32,
         BLOCK_SIZE_K=32,
         GROUP_SIZE_M=8,
     )
-
-    softmax_output = torch.empty(
-        score_output.shape, device=score_output.device, dtype=score_output.dtype)
+
+    softmax_output = torch.empty(score_output.shape, device=score_output.device, dtype=score_output.dtype)
     score_output_shape = score_output.shape
 
     score_output = score_output.view(-1, score_output.shape[-1])
     n_rows, n_cols = score_output.shape
 
     if n_rows <= 350000:
-
+
         block_size = max(triton.next_power_of_2(n_cols), 2)
         num_warps = 4
         if block_size >= 4096:
@@ -78,37 +89,39 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
         else:
             num_warps = 4
 
-        softmax_kernel[(n_rows, )](
+        softmax_kernel[(n_rows,)](
            softmax_output,
            score_output,
            score_output.stride(0),
            n_cols,
-            mask_ptr = input_mask,
+            mask_ptr=input_mask,
            num_warps=num_warps,
            BLOCK_SIZE=block_size,
        )
     else:
-        #TODO: change softmax kernel functions to make it suitable for large size dimension
+        # NOTE: change softmax kernel functions to make it suitable for large size dimension
         softmax_output = torch.nn.functional.softmax(score_output, dim=-1)
         softmax_output = softmax_output.view(*score_output_shape)
 
     batches, H, M, K = softmax_output.shape
     N = v.shape[-1]
 
-    output = torch.empty(
-        (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
+    output = torch.empty((batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype)
 
     grid = lambda meta: (
         batches,
         H,
-        triton.cdiv(M, meta["BLOCK_SIZE_M"]) *
-        triton.cdiv(N, meta["BLOCK_SIZE_N"]),
+        triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
     )
 
     qkv_gemm_4d_kernel[grid](
-        softmax_output, v, output,
-        M, N, K,
+        softmax_output,
+        v,
+        output,
+        M,
+        N,
+        K,
         softmax_output.stride(0),
         softmax_output.stride(1),
         softmax_output.stride(2),
@@ -129,7 +142,6 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t
     )
     return output.view(batches, -1, d_model)
 
-
def self_attention_compute_using_triton(qkv,
                                        input_mask,
                                        layer_past,
@@ -152,7 +164,6 @@ def self_attention_compute_using_triton(qkv,
     k = k.view(batches, -1, num_of_heads, head_size)
     v = v.view(batches, -1, num_of_heads, head_size)
 
-    data_output_triton = self_attention_forward_without_fusion(
-        q, k, v, input_mask, scale)
+    data_output_triton = self_attention_forward_without_fusion(q, k, v, input_mask, scale)
 
-    return data_output_triton
\ No newline at end of file
+    return data_output_triton
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index 55576e55fd2d..c8f852aef420 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -1,17 +1,17 @@
 import os
 
+import numpy as np
 import pytest
 import torch
-import numpy as np
+import torch.distributed as dist
+from transformers import LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-from transformers import LlamaForCausalLM, LlamaTokenizer
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.inference.tensor_parallel.engine import TPInferEngine
-import torch.distributed as dist
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
 TPSIZE = 2
@@ -19,20 +19,22 @@
 MAX_INPUT_LEN = 12
 MAX_OUTPUT_LEN = 100
 
+
 def init_to_get_rotary(self, base=10000):
     self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
     if not hasattr(self.config, "rope_scaling"):
         rope_scaling_factor = 1.0
     else:
         rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
-    if hasattr(self.config,"max_sequence_length"):
+    if hasattr(self.config, "max_sequence_length"):
         max_seq_len = self.config.max_sequence_length
-    elif hasattr(self.config,"max_position_embeddings"):
+    elif hasattr(self.config, "max_position_embeddings"):
         max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
     else:
-        max_seq_len = 2048 * rope_scaling_factor
+        max_seq_len = 2048 * rope_scaling_factor
     base = float(base)
-    inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_))
+    inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
+                             self.config.head_dim_))
     t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
     freqs = torch.outer(t, inv_freq)
@@ -40,21 +42,22 @@ def init_to_get_rotary(self, base=10000):
     self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
     return
 
+
 @parameterize('test_config', [{
     'tp_size': TPSIZE,
 }])
 def run_llama_test(test_config):
-
     llama_model_path = "/data/scratch/llama-7b-hf"
     tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
     init_to_get_rotary(model.model, base=10000)
     model = model.half()
-
+
     text = "how is weather today?"
     input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda')
-
+
     infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
     shardformer = ShardFormer(shard_config=shard_config)
@@ -65,7 +68,7 @@ def run_llama_test(test_config):
     generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
     outputs = infer_engine.generate(input_ids, generate_kwargs)
     print("outputs.shape: ", outputs.shape)
-
+
     print("outputs: ", outputs)
     output_text = tokenizer.decode(outputs[0])
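
For context only (an illustration, not part of the patch): the init_to_get_rotary helper touched in the test precomputes a rotary-embedding cos/sin cache on CPU and moves it to the GPU in fp16. The self-contained sketch below restates that computation with hypothetical values for head_dim and max_seq_len, to make the cached tensor shapes explicit.

# Hypothetical sketch of the rotary cache that init_to_get_rotary builds above.
# head_dim and max_seq_len are illustrative values, not taken from the patch.
import torch

head_dim = 128             # hidden_size // num_attention_heads for a LLaMA-7B-like config
max_seq_len = 2048
rope_scaling_factor = 1.0
base = 10000.0

# One inverse frequency per channel pair: shape (head_dim // 2,)
inv_freq = 1.0 / (base**(torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
# Positions (the test additionally reserves 1024 * 64 extra slots)
t = torch.arange(max_seq_len, dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)                  # (max_seq_len, head_dim // 2)

sin_cached = torch.sin(freqs).to(torch.float16)   # same shape as the _sin_cached table set in the test
cos_cached = torch.cos(freqs).to(torch.float16)
print(sin_cached.shape)                           # torch.Size([2048, 64])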