diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py
index df1b99769d3e..d55634a6f00b 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/inference/tensor_parallel/modeling/llama.py
@@ -81,7 +81,7 @@ def llama_model_forward(
         infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
         infer_state.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index)
     else:
-        # TODO handle the condition that no contiguous memory presents
+        infer_state.is_context_stage = False
         alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
         if alloc_mem is not None:
             infer_state.decode_is_contiguous = True
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
index 89646ca9f97f..55576e55fd2d 100644
--- a/tests/test_infer/test_llama_infer.py
+++ b/tests/test_infer/test_llama_infer.py
@@ -15,6 +15,9 @@
 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
 TPSIZE = 2
+BATCH_SIZE = 8
+MAX_INPUT_LEN = 12
+MAX_OUTPUT_LEN = 100
 
 
 def init_to_get_rotary(self, base=10000):
     self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
@@ -48,21 +51,20 @@ def run_llama_test(test_config):
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
     init_to_get_rotary(model.model, base=10000)
     model = model.half()
-    model.to(torch.cuda.current_device())
 
-    text = "Introduce some landmarks in Beijing"
-    input_ids = tokenizer.encode(text, return_tensors='pt')
-    # pg_mesh = ProcessGroupMesh(1, 1, test_config["tp_size"])
+    text = "how is weather today?"
+    input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda')
 
 
-    infer_engine = TPInferEngine(model.half(), 4, 12, 8)
+    infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
     shardformer = ShardFormer(shard_config=shard_config)
 
     infer_engine.prepare_with_shard_config(shard_config)
     infer_engine.shard_model_by(shardformer)
 
-    generate_kwargs = dict(do_sample=False)
+    generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
     outputs = infer_engine.generate(input_ids, generate_kwargs)
 
+    print("outputs.shape: ", outputs.shape)
     print("outputs: ", outputs)