From 6e501f2d7b738f7f78545ab564d2e1deeddd5642 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 28 Sep 2023 16:36:25 +0800 Subject: [PATCH 1/3] fix test bug --- .../modeling/chatglm2_6b/modeling_chatglm.py | 2 +- examples/inference/bench_llama.py | 2 +- tests/test_infer/test_bloom_infer.py | 29 +++++---- tests/test_infer/test_chatglm2_infer.py | 63 +++++++++++++------ tests/test_infer/test_llama_infer.py | 37 ++++++----- 5 files changed, 82 insertions(+), 51 deletions(-) diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index cbb25b5b1f4c..fdd49ecfeae5 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -873,7 +873,7 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): self.rotary_pos_emb = RotaryEmbedding( rotary_dim // 2, - original_impl=config.original_rope, + # original_impl=config.original_rope, # config has no attribute original_rope device=device, dtype=config.torch_dtype, ) diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 6e49fa80c812..8eb22a605347 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -68,7 +68,7 @@ def run_llama_test(args): model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) init_to_get_rotary(model.model, base=10000) model = model.half() - + print("model config:", model.config) model_config = model.config shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index 5a5d341fc6ba..ba978ad9bf0d 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -1,13 +1,14 @@ import pytest import torch from packaging import version +from transformers import BloomForCausalLM +from transformers.models.bloom.configuration_bloom import BloomConfig import colossalai from colossalai.inference.tensor_parallel import TPInferEngine from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo TP_SIZE = 2 MAX_BATCH_SIZE = 4 @@ -26,21 +27,23 @@ ], ) def run(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_bloom_for_causal_lm") - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - orig_model = model_fn() - orig_model = orig_model.half() - data = data_gen_fn() + bloom_config = BloomConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = BloomForCausalLM(bloom_config) + model = model.half() - shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True - ) - infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) + infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - generate_kwargs = dict(do_sample=False) - outputs = infer_engine.generate(data, **generate_kwargs) + input_tokens = { + "input_ids": torch.randint(1, 1000, (MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + "attention_mask": torch.ones((MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + } + outputs = infer_engine.generate(input_tokens, **generate_kwargs) - assert outputs is not None + assert outputs is not None def check_bloom(rank, world_size, port): diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 699ba7b52fe0..afafe64427ad 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -2,17 +2,15 @@ import pytest import torch -import torch.distributed as dist from packaging import version -from transformers import AutoTokenizer import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig +from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo.transformers.chatglm2 import infer_config os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 1 @@ -31,28 +29,55 @@ ], ) def run_chatglm2_test(test_config): - tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) - # pad_token_id = 0 - model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False) - orig_model = model_fn() - orig_model = orig_model.half() - text = ["how is the weather today?"] - input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True) + chagglm_config = ChatGLMConfig( + num_layers=2, + vocab_size=1200, + use_cache=True, + multi_query_attention=True, + multi_query_group_num=2, + num_attention_heads=8, + hidden_size=1024, + ) + model = ChatGLMForConditionalGeneration(chagglm_config) + model = model.half() + shard_config = ShardConfig( enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True ) - infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - outputs = infer_engine.generate(input_ids, **generate_kwargs) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + } + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + assert outputs is not None - # print("outputs.shape: ", outputs[0].shape) - # print("outputs: ", outputs[0]) - if not dist.is_initialized() or dist.get_rank() == 0: - for o in outputs: - output_text = tokenizer.decode(o) - print(output_text) + # tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) + # # pad_token_id = 0 + # model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False) + # orig_model = model_fn() + # orig_model = orig_model.half() + + # text = ["how is the weather today?"] + # input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True) + # shard_config = ShardConfig( + # enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + # ) + # infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + # generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + # outputs = infer_engine.generate(input_ids, **generate_kwargs) + # assert outputs is not None + + # # print("outputs.shape: ", outputs[0].shape) + # # print("outputs: ", outputs[0]) + # if not dist.is_initialized() or dist.get_rank() == 0: + # for o in outputs: + # output_text = tokenizer.decode(o) + # print(output_text) def check_chatglm2(rank, world_size, port): diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 0e5efe68508a..81405fcee51d 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -3,13 +3,14 @@ import pytest import torch from packaging import version +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 2 @@ -53,22 +54,24 @@ def init_to_get_rotary(self, base=10000): ], ) def run_llama_test(test_config): - sub_model_zoo = model_zoo.get_sub_registry("transformers_llama_for_casual_lm") - for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items(): - orig_model = model_fn() - init_to_get_rotary(orig_model.model, base=10000) - orig_model = orig_model.half() - data = data_gen_fn() - - shard_config = ShardConfig( - enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True - ) - infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - - generate_kwargs = dict(do_sample=False) - outputs = infer_engine.generate(data, **generate_kwargs) - - assert outputs is not None + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig( + enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True + ) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + init_to_get_rotary(model.model, base=10000) + generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) + + input_tokens = { + "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), + } + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + + assert outputs is not None def check_llama(rank, world_size, port): From 74b048fff8dc83d2279dfd9c67b633a1ad2aa809 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 28 Sep 2023 16:41:41 +0800 Subject: [PATCH 2/3] delete useless code --- examples/inference/bench_llama.py | 1 - tests/test_infer/test_chatglm2_infer.py | 24 ------------------- .../triton/test_llama2_token_attn.py | 4 +--- 3 files changed, 1 insertion(+), 28 deletions(-) diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 8eb22a605347..0d280bc71cb9 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -68,7 +68,6 @@ def run_llama_test(args): model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) init_to_get_rotary(model.model, base=10000) model = model.half() - print("model config:", model.config) model_config = model.config shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index afafe64427ad..28c029fde16c 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -55,30 +55,6 @@ def run_chatglm2_test(test_config): assert outputs is not None - # tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) - # # pad_token_id = 0 - # model_fn = lambda: ChatGLMForConditionalGeneration(infer_config, empty_init=False) - # orig_model = model_fn() - # orig_model = orig_model.half() - - # text = ["how is the weather today?"] - # input_ids = tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True) - # shard_config = ShardConfig( - # enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True - # ) - # infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - - # generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - # outputs = infer_engine.generate(input_ids, **generate_kwargs) - # assert outputs is not None - - # # print("outputs.shape: ", outputs[0].shape) - # # print("outputs: ", outputs[0]) - # if not dist.is_initialized() or dist.get_rank() == 0: - # for o in outputs: - # output_text = tokenizer.decode(o) - # print(output_text) - def check_chatglm2(rank, world_size, port): disable_existing_loggers() diff --git a/tests/test_infer_ops/triton/test_llama2_token_attn.py b/tests/test_infer_ops/triton/test_llama2_token_attn.py index c22f70211d4f..0537a3d76129 100644 --- a/tests/test_infer_ops/triton/test_llama2_token_attn.py +++ b/tests/test_infer_ops/triton/test_llama2_token_attn.py @@ -38,9 +38,7 @@ def test(): q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty_like() - # o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - + o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda") max_kv_cache_len = seq_len kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") From 1f57aa8c50e6269a6dccd7e8357e738a63a848b1 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Thu, 28 Sep 2023 17:23:28 +0800 Subject: [PATCH 3/3] fix typo --- tests/test_infer/test_chatglm2_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 28c029fde16c..399b70e1460e 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -29,7 +29,7 @@ ], ) def run_chatglm2_test(test_config): - chagglm_config = ChatGLMConfig( + chatglm_config = ChatGLMConfig( num_layers=2, vocab_size=1200, use_cache=True, @@ -38,7 +38,7 @@ def run_chatglm2_test(test_config): num_attention_heads=8, hidden_size=1024, ) - model = ChatGLMForConditionalGeneration(chagglm_config) + model = ChatGLMForConditionalGeneration(chatglm_config) model = model.half() shard_config = ShardConfig(