47 changes: 21 additions & 26 deletions colossalai/inference/tensor_parallel/engine.py
@@ -36,7 +36,6 @@ class TPInferEngine:
>>> generate_kwargs = ...
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
>>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
>>> infer_engine.optimize_model()
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
"""

@@ -48,11 +48,6 @@ def __init__(self,
max_output_len: int,
dtype: torch.dtype = torch.float16,
device: str = 'cuda') -> None:
self.model = model
self.model = self.model.to(device)
self.shard_config = shard_config
self.sharded_model = None

self.max_batch_size = max_batch_size
self.max_input_len = max_input_len
self.max_output_len = max_output_len
@@ -65,21 +59,26 @@ def __init__(self,

self.dtype = dtype

self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
self.head_num = self.model.config.num_attention_heads
self.layer_num = self.model.config.num_hidden_layers
self.head_dim = model.config.hidden_size // model.config.num_attention_heads
self.head_num = model.config.num_attention_heads
self.layer_num = model.config.num_hidden_layers

self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None

self.shard_config = shard_config
self.model = None
# optimize the original model by sharding with ShardFormer
self._optimize_model(model=model.to(device))

def _init_manager(self) -> None:
assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size # update sharded number of heads
self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
self.layer_num)

def optimize_model(self) -> None:
def _optimize_model(self, model: nn.Module) -> None:
"""
Optimize the original model by sharding with ShardFormer.
In subsequent generation, the sharded model is used instead of the original model.
@@ -88,8 +87,7 @@ def optimize_model(self) -> None:
assert self.shard_config.inference_only is True
shardformer = ShardFormer(shard_config=self.shard_config)
self._prepare_with_shard_config(shard_config=self.shard_config)
self._shard_model_by(shardformer)
self.model = None
self._shard_model_by(shardformer, model)

def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
""" Prepare the engine with a given ShardConfig.
@@ -119,15 +117,15 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)

return shard_config

def _shard_model_by(self, shardformer: ShardFormer) -> None:
def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
""" Shard original model by the given ShardFormer and store the sharded model. """
assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
"Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = self.model.__class__.__name__
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(self.model, inference_only=True)
self.sharded_model, _ = shardformer.optimize(self.model, policy)
self.sharded_model = self.sharded_model.cuda()
policy = get_autopolicy(model, inference_only=True)
self.model, _ = shardformer.optimize(model, policy)
self.model = self.model.cuda()

@property
def supported_models(self) -> List[str]:
@@ -152,10 +150,7 @@ def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor],
if 'max_new_tokens' not in generate_kwargs:
generate_kwargs.update(max_new_tokens=self.max_output_len)

if self.sharded_model is not None:
return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)

return self.model.generate(**input_tokens, **generate_kwargs)
return self._generate_by_set_infer_state(input_tokens, **generate_kwargs)

def prepare_batch_state(self, inputs) -> BatchInferState:
"""
@@ -239,7 +234,7 @@ def _generate_by_set_infer_state(self, input_tokens, **generate_kwargs) -> torch
"""

# for testing, always use sharded model
assert self.sharded_model is not None, "sharded model does not exist"
assert self.model is not None, "sharded model does not exist"

batch_infer_state = self.prepare_batch_state(input_tokens)
assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit"
@@ -248,14 +243,14 @@
# NOTE this is not a preferable way to pass BatchInferState during inference
# we might want to rewrite generate function (e.g. _generate_by_pass_infer_state)
# and pass BatchInferState via model forward
model = self.sharded_model
model = self.model
if isinstance(model, LlamaForCausalLM):
model = self.sharded_model.model
model = self.model.model
elif isinstance(model, BloomForCausalLM):
model = self.sharded_model.transformer
model = self.model.transformer
setattr(model, 'infer_state', batch_infer_state)

outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
outputs = self.model.generate(**input_tokens, **generate_kwargs, early_stopping=False)

# NOTE In future development, we're going to let the scheduler handle the cache,
# instead of freeing space explicitly at the end of generation
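With this change, sharding happens inside the constructor, so callers no longer invoke optimize_model() before generate(). A minimal usage sketch of the updated API follows; it is not part of this PR's code, and the model choice, size constants, and tp_size are illustrative assumptions. Like the tests in this PR, it is expected to run inside a colossalai-launched distributed process.

# Minimal sketch of post-PR usage: the engine shards the model with ShardFormer
# during __init__, so there is no separate optimize_model() call.
# Model choice, size constants, and tp_size below are illustrative assumptions.
import torch
from transformers import BloomConfig, BloomForCausalLM

from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig

MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 32
MAX_OUTPUT_LEN = 32


def run_inference(tp_size: int = 2) -> torch.Tensor:
    # small dummy model purely for illustration
    model_config = BloomConfig(num_hidden_layers=4, hidden_size=128,
                               intermediate_size=256, num_attention_heads=4)
    model = BloomForCausalLM(model_config).half()

    shard_config = ShardConfig(enable_tensor_parallelism=tp_size > 1, inference_only=True)
    # the engine shards the model during construction and moves it to the device
    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
    # generate() always routes through the sharded model after this PR
    return infer_engine.generate(input_ids, max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)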
1 change: 0 additions & 1 deletion examples/inference/bench_bloom.py
@@ -48,7 +48,6 @@ def bench_bloom(args):
# To benchmark torch original, comment out the line of optimizing model
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
infer_engine.optimize_model()

# prepare data for generation
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
1 change: 0 additions & 1 deletion examples/inference/bench_llama.py
@@ -72,7 +72,6 @@ def run_llama_test(args):

shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
infer_engine.optimize_model()

generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
input_tokens = {
3 changes: 0 additions & 3 deletions tests/test_infer/test_bloom_infer.py
@@ -2,9 +2,7 @@

import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import AutoTokenizer, BloomForCausalLM

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
@@ -35,7 +33,6 @@ def run(test_config):
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(orig_model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model()

generate_kwargs = dict(do_sample=False)
outputs = infer_engine.generate(data, **generate_kwargs)
83 changes: 20 additions & 63 deletions tests/test_infer/test_infer_engine.py
@@ -11,7 +11,7 @@
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

TP_SIZE = 2
@@ -22,31 +22,25 @@
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.5')


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
def test_prepare_data():
# dummy module used for testing
class DummyModule(nn.Module):

def __init__(self, config):
super(DummyModule, self).__init__()
self.config = config

def forward(self, x):
return x

# dummy config used for testing
class DummyModelConfig:

def __init__(self):
self.hidden_size = 4096
self.num_attention_heads = 32
self.num_hidden_layers = 8
@parameterize('test_config', [{
'tp_size': TP_SIZE,
}])
def run(test_config):
model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
model = BloomForCausalLM(model_config)
model = model.half()
model.to(torch.cuda.current_device())

dummy_config = DummyModelConfig()
model = DummyModule(dummy_config)
shard_config = ShardConfig(enable_tensor_parallelism=False)
# 1. check TPInferEngine init and model optimization
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

assert infer_engine.cache_manager is not None
assert infer_engine.tp_size == TP_SIZE
assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE

# 2. check data preparation
input_ids_list = [[80540, 15473, 3331, 11970, 90472, 361, 61335], [80540, 15473, 3331, 11970],
[80540, 15473, 3331, 11970], [80540, 15473]]
batch_size = len(input_ids_list)
@@ -74,49 +68,14 @@ def __init__(self):
# assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
# assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
def test_orig_generate():
# 3. check optimized model generate
input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))

model_config = LlamaConfig()
model = LlamaForCausalLM(model_config)
model = model.half()
model.to(torch.cuda.current_device())

shard_config = ShardConfig(enable_tensor_parallelism=False)

# init TPInferEngine
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

# original model generate
generate_kwargs = dict(do_sample=False)
infer_engine.generate(input_ids, **generate_kwargs)

torch.cuda.empty_cache()


@parameterize('test_config', [{
'tp_size': TP_SIZE,
}])
def run(test_config):
model_config = BloomConfig()
model = BloomForCausalLM(model_config)
model = model.half()
model.to(torch.cuda.current_device())

shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model()

assert infer_engine.cache_manager is not None
assert infer_engine.tp_size == TP_SIZE
assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE

torch.cuda.empty_cache()


def check_engine(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -127,11 +86,9 @@ def check_engine(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_engine_infer():
def test_engine():
spawn(check_engine, TP_SIZE)


if __name__ == '__main__':
test_prepare_data()
test_orig_generate()
test_engine_infer()
test_engine()
3 changes: 0 additions & 3 deletions tests/test_infer/test_llama_infer.py
@@ -3,9 +3,7 @@

import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import LlamaForCausalLM, LlamaTokenizer

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
@@ -61,7 +59,6 @@ def run_llama_test(test_config):
shard_config = ShardConfig(enable_tensor_parallelism=True if test_config['tp_size'] > 1 else False,
inference_only=True)
infer_engine = TPInferEngine(orig_model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.optimize_model()

generate_kwargs = dict(do_sample=False)
outputs = infer_engine.generate(data, **generate_kwargs)