diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index dad3f9cb295f..1f01460994d9 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -21,7 +21,7 @@ def run(): - model_path = "/data3/models/bloom-7b1" + model_path = "/home/lczyh/data3/models/bloom-7b1" if os.path.isdir(model_path) is False: return @@ -43,7 +43,7 @@ def run(): infer_engine.shard_model_by(shardformer) generate_kwargs = dict(do_sample=False) - outputs = infer_engine.generate(input_ids, generate_kwargs) + outputs = infer_engine.generate(input_ids, **generate_kwargs) if not dist.is_initialized() or dist.get_rank() == 0: output_text = tokenizer.decode(outputs[0]) diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 1d043ba59338..986f70633289 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -15,7 +15,7 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -TPSIZE = 1 +TPSIZE = 2 BATCH_SIZE = 8 MAX_INPUT_LEN = 12 MAX_OUTPUT_LEN = 100 @@ -46,10 +46,7 @@ def init_to_get_rotary(self, base=10000): return -@parameterize('test_config', [{ - 'tp_size': TPSIZE, -}]) -def run_llama_test(test_config): +def run_llama_test(): llama_model_path = "/data/scratch/llama-7b-hf" if os.path.isdir(llama_model_path) is False: @@ -73,14 +70,14 @@ def run_llama_test(test_config): infer_engine.shard_model_by(shardformer) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) - outputs = infer_engine.generate(input_ids, generate_kwargs) + outputs = infer_engine.generate(input_ids, **generate_kwargs) #print("outputs.shape: ", outputs.shape) #print("outputs: ", outputs) if not dist.is_initialized() or dist.get_rank() == 0: for o in outputs: output_text = tokenizer.decode(o) - #print(output_text) + # print(output_text) def check_llama(rank, world_size, port):