28 changes: 22 additions & 6 deletions colossalai/inference/tensor_parallel/engine.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
@@ -7,7 +7,6 @@
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.tokenization_utils_base import BatchEncoding

from colossalai.cluster import ProcessGroupMesh
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.auto_policy import get_autopolicy

@@ -29,6 +28,7 @@ def __init__(self,
dtype: torch.dtype = torch.float16,
device: str = 'cuda') -> None:
self.model = model
self.model = self.model.to(device)
self.sharded_model = None

self.max_batch_size = max_batch_size
@@ -57,7 +57,18 @@ def _init_manager(self) -> None:
self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
self.layer_num)

def prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
def optimize_model(self, config: Optional[Dict[Any, Any]] = None) -> None:
""" Apply shardformer to optimize the model. In future generation, use sharded model instead of original model. """
tp_size = 1 if config is None else config.get('tp_size', 1)
# NOTE: we will switch to a dedicated inference config later, with the additional attributes we need
# tp_size = getattr(config, 'tp_size', 1)
shard_config = ShardConfig(enable_tensor_parallelism=True if tp_size > 1 else False, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)
self._prepare_with_shard_config(shard_config=shard_config)
self._shard_model_by(shardformer)
self.model = None

def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
""" Prepare the engine with a given ShardConfig, or create a default one with tp size 1 """
self.tp_size = 1
if shard_config is None:
@@ -80,7 +91,7 @@ def prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)

return shard_config

def shard_model_by(self, shardformer: ShardFormer) -> None:
def _shard_model_by(self, shardformer: ShardFormer) -> None:
""" Shard the model and store the sharded model by given ShardFormer """
assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
"Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
@@ -100,11 +111,13 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor],
for t in input_tokens:
if torch.is_tensor(input_tokens[t]):
input_tokens[t] = input_tokens[t].cuda()
if 'max_new_tokens' not in generate_kwargs:
generate_kwargs.update(max_new_tokens=self.max_output_len)

if self.sharded_model is not None:
return self.generate_by_set_infer_state(input_tokens, **generate_kwargs)

return self.model.generate(**input_tokens, **generate_kwargs)
return self.model.generate(input_tokens.get('input_ids'), **generate_kwargs)

@torch.no_grad()
def generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
@@ -135,9 +148,12 @@ def generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch.Tensor:
model = self.sharded_model.transformer
setattr(model, 'infer_state', batch_infer_state)

generate_kwargs.update(max_new_tokens=self.max_output_len)
outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)

# NOTE In future development, we're going to let the scheduler handle the cache,
# instead of freeing space explicitly at the end of generation
self.cache_manager.free_all()

return outputs

def prepare_batch_state(self, inputs) -> BatchInferState:
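For context, this refactor folds the previous two-step setup (prepare_with_shard_config followed by shard_model_by) into the single optimize_model call, which builds the ShardConfig and ShardFormer internally and drops the reference to the unsharded model afterwards. A minimal sketch of the new calling pattern follows; the model name, size limits, and tp_size are placeholders rather than part of this PR, and a tp_size greater than 1 additionally requires a distributed launch, as in the benchmarks and tests below.

```python
from transformers import AutoTokenizer, BloomForCausalLM

from colossalai.inference.tensor_parallel import TPInferEngine

# Placeholder limits; real values come from the caller's configuration.
MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN = 4, 16, 32

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
tokenizer.pad_token = tokenizer.eos_token
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m",
                                         pad_token_id=tokenizer.eos_token_id).half()

engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
# One call replaces the old prepare_with_shard_config + shard_model_by pair.
engine.optimize_model({'tp_size': 1})

inputs = tokenizer(["how is the weather today?"], return_tensors='pt', padding=True)
# max_new_tokens falls back to the engine's max_output_len when not given.
outputs = engine.generate(inputs, do_sample=False)
```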
15 changes: 4 additions & 11 deletions examples/inference/bench_bloom.py
@@ -48,17 +48,11 @@ def bench_bloom(test_config):
tokenizer.pad_token = tokenizer.eos_token
model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
model = model.half()
# To benchmark torch original, uncommment the following line
# model.to(torch.cuda.current_device())

# init TPInferEngine and shard original model by shardformer
# To benchmark torch original, comment out lines of creating, preparing, and sharding by the shardformer
# init TPInferEngine and shard the original model
# To benchmark the original torch model, comment out the optimize_model call below
infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)
infer_engine.prepare_with_shard_config(shard_config)
infer_engine.shard_model_by(shardformer)
infer_engine.optimize_model(test_config)

# prepare data for generation
batch_size = MAX_BATCH_SIZE
@@ -78,10 +72,9 @@ def bench_bloom(test_config):
for i in range(iters):
torch.cuda.synchronize()
start = time.time()
outputs = infer_engine.generate(input_tokens, generate_kwargs)
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
torch.cuda.synchronize()
end = time.time()
infer_engine.cache_manager.free_all()
out_len = outputs.shape[1]
print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
times.append((end - start) / (out_len - input_len))
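For reference, the per-token time the benchmark appends is the wall-clock time of one generate call divided by the number of newly generated tokens (out_len - input_len). A small sketch of that measurement, with the helper name chosen here purely for illustration:

```python
import time

import torch


def seconds_per_generated_token(engine, input_tokens, input_len, **generate_kwargs):
    """Time one generate() call and normalize by the tokens produced beyond the prompt."""
    torch.cuda.synchronize()
    start = time.time()
    outputs = engine.generate(input_tokens, **generate_kwargs)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    new_tokens = outputs.shape[1] - input_len  # generated tokens only
    return elapsed / new_tokens
```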
11 changes: 3 additions & 8 deletions examples/inference/bench_llama.py
@@ -77,12 +77,8 @@ def run_llama_test(test_config):

model_config = model.config

infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)

infer_engine.prepare_with_shard_config(shard_config)
infer_engine.shard_model_by(shardformer)
infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model(test_config)

batch_size = 2
max_new_tokens = 128
@@ -100,13 +96,12 @@ def run_llama_test(test_config):
for i in range(iters):
torch.cuda.synchronize()
start = time.time()
outputs = infer_engine.generate(input_tokens, generate_kwargs)
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
torch.cuda.synchronize()
end = time.time()
out_len = outputs.shape[1]
print("generation time {} s".format(str(end - start)))
times.append((end - start) / (out_len - input_len))
infer_engine.cache_manager.free_all()

print("outputs, ", len(outputs))
outputs = tokenizer.batch_decode(outputs)
33 changes: 18 additions & 15 deletions tests/test_infer/test_bloom_infer.py
@@ -4,13 +4,12 @@
import torch
import torch.distributed as dist
from packaging import version
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM
from transformers import AutoTokenizer, BloomForCausalLM

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

TP_SIZE = 2
MAX_BATCH_SIZE = 4
@@ -20,34 +19,38 @@
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')


def run():
model_path = "/home/lczyh/data3/models/bloom-7b1"
@parameterize('test_config', [{
'tp_size': TP_SIZE,
}])
def run(test_config):
model_path = "/data3/models/bloom-7b1"
if os.path.isdir(model_path) is False:
return

tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

text = "Introduce some landmarks in Beijing"
input_ids = tokenizer.batch_encode_plus([text], return_tensors='pt')
text1 = "Introduce some landmarks in Beijing"
text2 = "how is weather today?"
input_ids = tokenizer.batch_encode_plus([text1, text2], return_tensors='pt', padding=True)

model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
model = model.half()
model.to(torch.cuda.current_device())

shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)

infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.prepare_with_shard_config(shard_config=shard_config)
infer_engine.shard_model_by(shardformer)
infer_engine.optimize_model(test_config)

generate_kwargs = dict(do_sample=False)
outputs = infer_engine.generate(input_ids, **generate_kwargs)

assert outputs is not None

if not dist.is_initialized() or dist.get_rank() == 0:
output_text = tokenizer.decode(outputs[0])
print(output_text)
# output_text = tokenizer.decode(outputs[0])
# print(output_text)
for o in outputs:
output_text = tokenizer.decode(o)
# print(output_text)


def check_engine(rank, world_size, port):
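The check_engine launcher is truncated in this view; colossalai test files typically wire a parameterized run() into a spawn-based multiprocess test as sketched below. The launch arguments, decorator order, and test name are assumptions rather than part of this diff; run, CUDA_SUPPORT, and TP_SIZE refer to the definitions shown earlier in the file.

```python
import pytest

import colossalai
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn


def check_engine(rank, world_size, port):
    # Assumed launch call, following the usual colossalai distributed-test harness.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run()  # the @parameterize-decorated test body defined above


@pytest.mark.skipif(not CUDA_SUPPORT, reason="requires CUDA version higher than 11.5")
@clear_cache_before_run()
@rerun_if_address_is_in_use()
def test_bloom_infer():
    spawn(check_engine, TP_SIZE)
```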
14 changes: 6 additions & 8 deletions tests/test_infer/test_infer_engine.py
@@ -85,9 +85,8 @@ def test_orig_generate():

shard_config = ShardConfig(enable_tensor_parallelism=False)

# init TPInferEngine and
# init TPInferEngine
infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.prepare_with_shard_config(shard_config)

# original model generate
generate_kwargs = dict(do_sample=False)
@@ -96,18 +95,17 @@
torch.cuda.empty_cache()


def run():
@parameterize('test_config', [{
'tp_size': TP_SIZE,
}])
def run(test_config):
model_config = BloomConfig()
model = BloomForCausalLM(model_config)
model = model.half()
model.to(torch.cuda.current_device())

shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)

infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.prepare_with_shard_config(shard_config=shard_config)
infer_engine.shard_model_by(shardformer)
infer_engine.optimize_model(test_config)

assert infer_engine.cache_manager is not None
assert infer_engine.tp_size == TP_SIZE
18 changes: 8 additions & 10 deletions tests/test_infer/test_llama_infer.py
@@ -46,7 +46,10 @@ def init_to_get_rotary(self, base=10000):
return


def run_llama_test():
@parameterize('test_config', [{
'tp_size': TPSIZE,
}])
def run_llama_test(test_config):

llama_model_path = "/data/scratch/llama-7b-hf"
if os.path.isdir(llama_model_path) is False:
@@ -61,19 +64,14 @@ def run_llama_test():
text = ["how is weather today?", "i am "]
input_ids = tokenizer.batch_encode_plus(text, return_tensors='pt', padding=True, device='cuda')

#print("input ids ", input_ids)
infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)

infer_engine.prepare_with_shard_config(shard_config)
infer_engine.shard_model_by(shardformer)
infer_engine = TPInferEngine(model, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model(test_config)

generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
outputs = infer_engine.generate(input_ids, **generate_kwargs)
#print("outputs.shape: ", outputs.shape)

#print("outputs: ", outputs)
assert outputs is not None

if not dist.is_initialized() or dist.get_rank() == 0:
for o in outputs:
output_text = tokenizer.decode(o)