diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index 2fb76d3e5e58..c6abb74f080b 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -36,11 +36,11 @@ def __init__(self,
         self.max_output_len = max_output_len
         self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len)
 
-        # Constraints relatable with specs of devices and model
-        # This may change into an optional arg in the future
+        # Constraints relatable with specs of devices
         assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
-        assert self.max_input_len + self.max_output_len <= 4096, "Max length exceeds the constraint"
+        assert self.max_input_len + self.max_output_len <= 2048, "Max length exceeds the constraint"
 
+        torch.device(device=device)
         self.dtype = dtype
 
         self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
@@ -94,7 +94,7 @@ def shard_model_by(self, shardformer: ShardFormer) -> None:
     def _supported_models() -> List[str]:
         return _supported_models
 
-    def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor:
+    def generate(self, input_tokens, generate_kwargs) -> torch.Tensor:
         if isinstance(input_tokens, torch.Tensor):
             input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool))
         for t in input_tokens:
@@ -102,12 +102,12 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor],
                 input_tokens[t] = input_tokens[t].cuda()
 
         if self.sharded_model is not None:
-            return self.generate_by_set_infer_state(input_tokens, **generate_kwargs)
+            return self.generate_by_set_infer_state(input_tokens, generate_kwargs)
 
         return self.model.generate(**input_tokens, **generate_kwargs)
 
     @torch.no_grad()
-    def generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
+    def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Tensor:
         """
         Generate output tokens by setting BatchInferState as an attribute to the model and calling model.generate
@@ -191,9 +191,8 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
                 start_index += curr_seq_len
                 max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
         else:
-            length = max(len(input_id) for input_id in input_ids_list)
             for i, input_ids in enumerate(input_ids_list):
-                curr_seq_len = length
+                curr_seq_len = len(input_ids)
                 seq_lengths[i] = curr_seq_len
                 seq_start_indexes[i] = start_index
                 start_index += curr_seq_len
diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py
index 274c01841279..2ddb6c5cdb35 100644
--- a/colossalai/inference/tensor_parallel/kvcache_manager.py
+++ b/colossalai/inference/tensor_parallel/kvcache_manager.py
@@ -3,7 +3,8 @@
 # https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
 
 import torch
-from transformers.utils import logging
+
+from colossalai.logging import get_dist_logger
 
 
 class MemoryManager:
@@ -26,7 +27,7 @@ def __init__(self,
                  head_dim: int,
                  layer_num: int,
                  device: torch.device = torch.device('cuda')):
-        self.logger = logging.get_logger(__name__)
+        self.logger = get_dist_logger(__name__)
         self.available_size = size
         self.past_key_values_length = 0
         self._init_mem_states(size, device)
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 2a5a5b4cf64c..3730e2c1e1ec 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -7,9 +7,16 @@
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
-from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaForCausalLM,
+    LlamaForSequenceClassification,
+    LlamaModel,
+    apply_rotary_pos_emb,
+)
 from transformers.utils import logging
 
+from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
@@ -395,10 +402,6 @@ def get_llama_flash_attention_forward():
 
     from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 
-    from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-
-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
-
     def forward(
         self: LlamaAttention,
         hidden_states: torch.Tensor,
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 412f78fd21fa..35b526456b10 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -29,8 +29,6 @@ class ShardConfig:
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
     inference_only: bool = False
-    enable_sequence_parallelism: bool = False
-    enable_sequence_overlap: bool = False
     # pipeline_parallel_size: int
     # data_parallel_size: int
 
@@ -60,8 +58,6 @@ def _turn_on_all_optimization(self):
         self.enable_fused_normalization = True
         self.enable_flash_attention = True
         self.enable_jit_fused = True
-        self.enable_sequence_parallelism = True
-        self.enable_sequence_overlap = True
 
     def _infer(self):
         """
diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py
index bc96ee137353..c9432509d941 100644
--- a/tests/test_infer/test_infer_engine.py
+++ b/tests/test_infer/test_infer_engine.py
@@ -1,9 +1,9 @@
+import pytest
 from itertools import accumulate
+from packaging import version
 
-import pytest
 import torch
 import torch.nn as nn
-from packaging import version
 from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
 from transformers.tokenization_utils_base import BatchEncoding
 
@@ -21,7 +21,6 @@
 
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')
 
-
 @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 def test_prepare_data():
     # dummy module used for testing
@@ -55,24 +54,22 @@ def __init__(self):
         attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))]
     data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
     inputs_batch_encoding = BatchEncoding(data=data)
+
     seq_lengths = [len(li) for li in input_ids_list]
     start_loc = list(accumulate([0] + seq_lengths[:-1]))
     seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
     start_loc = torch.tensor(start_loc, dtype=torch.int32)
+
     # input token id list as inputs
     batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
     # BatchEncoding as inputs
     batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)
 
     assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
-    assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len)
-
-    # The following tests are discarded for now, and will be reused after all features are added
-    # assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
-    # assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
-    # assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
-    # assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)
-
+    assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
+    assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
+    assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
+    assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)
 
 @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
 def test_orig_generate():
@@ -91,7 +88,7 @@ def test_orig_generate():
 
     # original model generate
     generate_kwargs = dict(do_sample=False)
-    infer_engine.generate(input_ids, **generate_kwargs)
+    infer_engine.generate(input_ids, generate_kwargs)
 
     torch.cuda.empty_cache()