26 changes: 17 additions & 9 deletions colossalai/inference/tensor_parallel/engine.py
@@ -163,29 +163,37 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
         if not isinstance(inputs, (BatchEncoding, dict, list, torch.Tensor)):
             raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state")

+        input_ids_list = None
+        attention_mask = None
+
         if isinstance(inputs, (BatchEncoding, dict)):
-            attn_masks = inputs['attention_mask']
-            batch_size = attn_masks.shape[0]
-            max_len_in_batch = attn_masks.shape[1]
-        elif isinstance(inputs, list):
-            batch_size = len(inputs)
+            input_ids_list = inputs['input_ids']
+            attention_mask = inputs['attention_mask']
         else:
-            batch_size = inputs.shape[0]
+            input_ids_list = inputs
+        if isinstance(input_ids_list[0], int):    # for a single input
+            input_ids_list = [input_ids_list]
+            attention_mask = [attention_mask] if attention_mask is not None else attention_mask
+
+        batch_size = len(input_ids_list)

         seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
         seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
         start_index = 0

         max_len_in_batch = -1
         if isinstance(inputs, (BatchEncoding, dict)):
-            for i, attn_mask in enumerate(attn_masks):
-                curr_seq_len = int(torch.sum(attn_mask))
+            for i, attn_mask in enumerate(attention_mask):
+                if isinstance(attn_mask, torch.Tensor):
+                    curr_seq_len = int(torch.sum(attn_mask))
+                else:
+                    curr_seq_len = int(sum(attn_mask))
                 seq_lengths[i] = curr_seq_len
                 seq_start_indexes[i] = start_index
                 start_index += curr_seq_len
                 max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
         else:
-            for i, input_ids in enumerate(inputs):
+            for i, input_ids in enumerate(input_ids_list):
                 curr_seq_len = len(input_ids)
                 seq_lengths[i] = curr_seq_len
                 seq_start_indexes[i] = start_index
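For context only (not part of the diff): a minimal plain-Python sketch of the sequence-layout bookkeeping the updated list-input branch performs. The real method stores the same values in the int32 CUDA tensors created above; the sample token ids below are made up.

    # Illustrative sketch of the layout prepare_batch_state derives for list inputs.
    input_ids_list = [[11, 12, 13, 14], [21, 22], [31, 32, 33]]

    seq_lengths = [len(ids) for ids in input_ids_list]    # [4, 2, 3]
    seq_start_indexes = []
    start_index = 0
    for curr_seq_len in seq_lengths:
        seq_start_indexes.append(start_index)             # [0, 4, 6]
        start_index += curr_seq_len
    max_len_in_batch = max(seq_lengths)                    # 4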
83 changes: 71 additions & 12 deletions tests/test_infer/test_infer_engine.py
@@ -1,54 +1,112 @@
+from itertools import accumulate
+
 import pytest
 import torch
-from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
+import torch.nn as nn
+from transformers import BloomConfig, BloomForCausalLM, LlamaConfig, LlamaForCausalLM
+from transformers.tokenization_utils_base import BatchEncoding

 import colossalai
 from colossalai.inference.tensor_parallel import TPInferEngine
 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

 TP_SIZE = 2
-BATCH_SIZE = 4
+MAX_BATCH_SIZE = 4
 MAX_INPUT_LEN = 16
 MAX_OUTPUT_LEN = 8


+def test_prepare_data():
+    # dummy module used for testing
+    class DummyModule(nn.Module):
+
+        def __init__(self, config):
+            super(DummyModule, self).__init__()
+            self.config = config
+
+        def forward(self, x):
+            return x
+
+    # dummy config used for testing
+    class DummyModelConfig:
+
+        def __init__(self):
+            self.hidden_size = 4096
+            self.num_attention_heads = 32
+            self.num_hidden_layers = 8
+
+    dummy_config = DummyModelConfig()
+    model = DummyModule(dummy_config)
+    infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+
+    input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970],
+                      [80540, 15473, 3331, 11970], [80540, 15473]]
+    batch_size = len(input_ids_list)
+    max_seq_len = max(len(li) for li in input_ids_list)
+    attention_mask = [[0] * max_seq_len for _ in range(batch_size)]
+    for i, li in enumerate(input_ids_list):
+        attention_mask[i][max_seq_len - len(li):] = [1 for _ in range(len(li))]
+    data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
+    inputs_batch_encoding = BatchEncoding(data=data)
+
+    seq_lengths = [len(li) for li in input_ids_list]
+    start_loc = list(accumulate([0] + seq_lengths[:-1]))
+    seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
+    start_loc = torch.tensor(start_loc, dtype=torch.int32)
+
+    # BatchEncoding as inputs
+    batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
+    # input token id list as inputs
+    batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)

+    assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
+    assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
+    assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
+    assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
+    assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)
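As a quick check of the reference values above (not part of the test file): the four inputs have lengths 7, 4, 4 and 2, so the tensors the assertions compare against work out to:

    seq_lengths = torch.tensor([7, 4, 4, 2], dtype=torch.int32)
    start_loc = torch.tensor([0, 7, 11, 15], dtype=torch.int32)    # accumulate([0, 7, 4, 4])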


 def test_orig_generate():
-    input_ids = torch.randint(low=10, high=1000, size=(BATCH_SIZE, MAX_INPUT_LEN))
+    input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))

     model_config = LlamaConfig()
     model = LlamaForCausalLM(model_config)
     model = model.half()
     model.to(torch.cuda.current_device())

     shard_config = ShardConfig(enable_tensor_parallelism=False)

     # init TPInferEngine and
-    infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     infer_engine.prepare_with_shard_config(shard_config)

     # original model generate
     generate_kwargs = dict(do_sample=False)
     infer_engine.generate(input_ids, generate_kwargs)

     torch.cuda.empty_cache()


 def run():
-    model_config = LlamaConfig()
-    model = LlamaForCausalLM(model_config)
+    model_config = BloomConfig()
+    model = BloomForCausalLM(model_config)
     model = model.half()
     model.to(torch.cuda.current_device())

     shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
     shardformer = ShardFormer(shard_config=shard_config)

-    infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
+    infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     infer_engine.prepare_with_shard_config(shard_config=shard_config)
     infer_engine.shard_model_by(shardformer)

     assert infer_engine.cache_manager is not None
     assert infer_engine.tp_size == TP_SIZE
     assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE

     # TODO After adding forward replacement for CausalLM,
     # uncomment these lines to test sharded model generate
     # generate_kwargs = dict(do_sample=False)
     # infer_engine.generate(input_ids, generate_kwargs)

     torch.cuda.empty_cache()


@@ -66,5 +124,6 @@ def test_engine_infer():


 if __name__ == '__main__':
+    test_prepare_data()
     test_orig_generate()
     test_engine_infer()
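The body of test_engine_infer is collapsed in this diff, so the sketch below is only an illustration of the usual ColossalAI testing pattern for driving a per-rank function such as run() across TP_SIZE processes; the helper name check_engine and the exact decorators are assumptions, not the PR's actual code.

    import pytest

    import colossalai
    from colossalai.logging import disable_existing_loggers
    from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


    def check_engine(rank, world_size, port):
        # each spawned process sets up its own distributed context, then runs the test body
        disable_existing_loggers()
        colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
        run()


    @pytest.mark.dist
    @rerun_if_address_is_in_use()
    @clear_cache_before_run()
    def test_engine_infer():
        spawn(check_engine, TP_SIZE)    # one process per tensor-parallel rank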