From 264bbf6624b385a055b64ee2eba099726765eef5 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 31 Aug 2023 15:24:00 +0800 Subject: [PATCH 1/5] fix engine prepare data --- .../inference/tensor_parallel/engine.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 52d2fc05ffbb..01763c850381 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -163,14 +163,19 @@ def prepare_batch_state(self, inputs) -> BatchInferState: if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)): raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state") + input_ids_list = None + attention_mask = None + if isinstance(inputs, (BatchEncoding, dict)): - attn_masks = inputs['attention_mask'] - batch_size = attn_masks.shape[0] - max_len_in_batch = attn_masks.shape[1] - elif isinstance(inputs, list): - batch_size = len(inputs) + input_ids_list = inputs['input_ids'] + attention_mask = inputs['attention_mask'] else: - batch_size = inputs.shape[0] + input_ids_list = inputs + if isinstance(input_ids_list[0], int): # for a single input + input_ids_list = [input_ids_list] + attention_mask = [attention_mask] if attention_mask is not None else attention_mask + + batch_size = len(input_ids_list) seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda') seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda') @@ -178,14 +183,17 @@ def prepare_batch_state(self, inputs) -> BatchInferState: max_len_in_batch = -1 if isinstance(inputs, (BatchEncoding, dict)): - for i, attn_mask in enumerate(attn_masks): - curr_seq_len = int(torch.sum(attn_mask)) + for i, attn_mask in enumerate(attention_mask): + if isinstance(attn_mask, torch.Tensor): + curr_seq_len = int(torch.sum(attn_mask)) + else: + curr_seq_len = int(sum(attn_mask)) 
seq_lengths[i] = curr_seq_len seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch else: - for i, input_ids in enumerate(inputs): + for i, input_ids in enumerate(input_ids_list): curr_seq_len = len(input_ids) seq_lengths[i] = curr_seq_len seq_start_indexes[i] = start_index From 917590bb524ee0136af9f6a4bb8713b20b894bf4 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 31 Aug 2023 15:24:32 +0800 Subject: [PATCH 2/5] add engine test --- tests/test_infer/test_infer_engine.py | 72 ++++++++++++++++++++++++--- 1 file changed, 65 insertions(+), 7 deletions(-) diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index 7fcb36554b90..a6b69f1cf0a9 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -1,28 +1,83 @@ +from itertools import accumulate + import pytest import torch -from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM +import torch.nn as nn +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaTokenizer +from transformers.tokenization_utils_base import BatchEncoding import colossalai from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn TP_SIZE = 2 -BATCH_SIZE = 4 +MAX_BATCH_SIZE = 4 MAX_INPUT_LEN = 16 MAX_OUTPUT_LEN = 8 +def test_prepare_data(): + # dummy module used for testing + class DummyModule(nn.Module): + + def __init__(self, config): + super(DummyModule, self).__init__() + self.config = config + + def forward(self, x): + return x + + # dummy config used for testing + class DummyModelConfig: + + def __init__(self): + 
self.hidden_size = 4096 + self.num_attention_heads = 32 + self.num_hidden_layers = 8 + + dummy_config = DummyModelConfig() + model = DummyModule(dummy_config) + infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970], + [80540, 15473, 3331, 11970], [80540, 15473]] + batch_size = len(input_ids_list) + max_seq_len = max(len(li) for li in input_ids_list) + attention_mask = [[0] * max_seq_len for _ in range(batch_size)] + for i, li in enumerate(input_ids_list): + attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))] + data = dict(input_ids=input_ids_list, attention_mask=attention_mask) + inputs_batch_encoding = BatchEncoding(data=data) + + seq_lengths = [len(li) for li in input_ids_list] + start_loc = list(accumulate([0] + seq_lengths[:-1])) + seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32) + start_loc = torch.tensor(start_loc, dtype=torch.int32) + + # BatchEncoding as inputs + batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding) + # input token id list as inputs + batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list) + + assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size + assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths) + assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths) + assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc) + assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc) + + def test_orig_generate(): - input_ids = torch.randint(low=10, high=1000, size=(BATCH_SIZE, MAX_INPUT_LEN)) + input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN)) model_config = LlamaConfig() model = LlamaForCausalLM(model_config) shard_config = ShardConfig(enable_tensor_parallelism=False) # init TPInferEngine and - infer_engine = 
TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) infer_engine.prepare_with_shard_config(shard_config) # original model generate @@ -31,12 +86,14 @@ def test_orig_generate(): def run(): + input_ids = torch.tensor([[80540, 15473, 3331, 11970, 90472, 361, 61335]], dtype=torch.int64) + model_config = LlamaConfig() model = LlamaForCausalLM(model_config) shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) - infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) infer_engine.prepare_with_shard_config(shard_config=shard_config) infer_engine.shard_model_by(shardformer) @@ -46,8 +103,8 @@ def run(): # TODO After adding forward replacement for CausalLM, # uncomment these lines to test sharded model generate - # generate_kwargs = dict(do_sample=False) - # infer_engine.generate(input_ids, generate_kwargs) + generate_kwargs = dict(do_sample=False) + infer_engine.generate(input_ids, generate_kwargs) torch.cuda.empty_cache() @@ -66,5 +123,6 @@ def test_engine_infer(): if __name__ == '__main__': + test_prepare_data() test_orig_generate() test_engine_infer() From 065c2b06b3681eb27e6f31f7097217dbe6977b41 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 31 Aug 2023 15:37:26 +0800 Subject: [PATCH 3/5] use bloom for testing --- tests/test_infer/test_infer_engine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index a6b69f1cf0a9..43ef43beadc6 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -3,7 +3,7 @@ import pytest import torch import torch.nn as nn -from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaTokenizer 
+from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM from transformers.tokenization_utils_base import BatchEncoding import colossalai @@ -72,8 +72,8 @@ def __init__(self): def test_orig_generate(): input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN)) - model_config = LlamaConfig() - model = LlamaForCausalLM(model_config) + model_config = BloomConfig() + model = BloomForCausalLM(model_config) shard_config = ShardConfig(enable_tensor_parallelism=False) # init TPInferEngine and @@ -88,8 +88,8 @@ def test_orig_generate(): def run(): input_ids = torch.tensor([[80540, 15473, 3331, 11970, 90472, 361, 61335]], dtype=torch.int64) - model_config = LlamaConfig() - model = LlamaForCausalLM(model_config) + model_config = BloomConfig() + model = BloomForCausalLM(model_config) shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) From 116404781bf8a5ea0021bebf024985e68eb7941d Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 31 Aug 2023 16:13:21 +0800 Subject: [PATCH 4/5] revise on test --- .../inference/tensor_parallel/modeling/bloom.py | 3 +++ tests/test_infer/test_infer_engine.py | 12 ++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index 0fd08d3721e6..7ebd68d22b93 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -457,6 +457,9 @@ def bloom_attention_forward( # output = self.output[:batch_size*q_length, :, :] output = torch.empty_like(q) + print(self.num_heads) + print(k.shape) + bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) context_layer = output.view(batch_size, q_length, H * D_HEAD) diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py 
index 43ef43beadc6..d6a6d3465e03 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -72,8 +72,11 @@ def __init__(self): def test_orig_generate(): input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN)) - model_config = BloomConfig() - model = BloomForCausalLM(model_config) + model_config = LlamaConfig() + model = LlamaForCausalLM(model_config) + model = model.half() + model.to(torch.cuda.current_device()) + shard_config = ShardConfig(enable_tensor_parallelism=False) # init TPInferEngine and @@ -84,12 +87,17 @@ def test_orig_generate(): generate_kwargs = dict(do_sample=False) infer_engine.generate(input_ids, generate_kwargs) + torch.cuda.empty_cache() + def run(): input_ids = torch.tensor([[80540, 15473, 3331, 11970, 90472, 361, 61335]], dtype=torch.int64) model_config = BloomConfig() model = BloomForCausalLM(model_config) + model = model.half() + model.to(torch.cuda.current_device()) + shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) From 586ee081f257d94dc239fd19ceaae09f5725ab32 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 31 Aug 2023 16:27:59 +0800 Subject: [PATCH 5/5] revise on test --- colossalai/inference/tensor_parallel/modeling/bloom.py | 3 --- tests/test_infer/test_infer_engine.py | 7 ------- 2 files changed, 10 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index 7ebd68d22b93..0fd08d3721e6 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -457,9 +457,6 @@ def bloom_attention_forward( # output = self.output[:batch_size*q_length, :, :] output = torch.empty_like(q) - print(self.num_heads) - print(k.shape) - bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) context_layer = 
output.view(batch_size, q_length, H * D_HEAD) diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py index d6a6d3465e03..6fcf9fe0f387 100644 --- a/tests/test_infer/test_infer_engine.py +++ b/tests/test_infer/test_infer_engine.py @@ -91,8 +91,6 @@ def test_orig_generate(): def run(): - input_ids = torch.tensor([[80540, 15473, 3331, 11970, 90472, 361, 61335]], dtype=torch.int64) - model_config = BloomConfig() model = BloomForCausalLM(model_config) model = model.half() @@ -109,11 +107,6 @@ def run(): assert infer_engine.tp_size == TP_SIZE assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE - # TODO After adding forward replacement for CausalLM, - # uncomment these lines to test sharded model generate - generate_kwargs = dict(do_sample=False) - infer_engine.generate(input_ids, generate_kwargs) - torch.cuda.empty_cache()