6 changes: 2 additions & 4 deletions colossalai/inference/tensor_parallel/__init__.py
@@ -1,6 +1,4 @@
from .modeling.llama import LlamaInferenceForwards
from .pollcies.llama import LlamaModelInferPolicy
from .engine import TPInferEngine
from .kvcache_manager import MemoryManager
__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'MemoryManager', 'TPInferEngine']

__all__ = ['MemoryManager', 'TPInferEngine']
9 changes: 2 additions & 7 deletions colossalai/inference/tensor_parallel/engine.py
@@ -141,7 +141,6 @@ def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Te

outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)

print(f"outputs.shape {outputs.shape}")
return outputs

def prepare_batch_state(self, inputs) -> BatchInferState:
@@ -193,11 +192,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch

print(" 666 ", max_len_in_batch)

block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len),
dtype=torch.long,
device='cuda')
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda')
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device
batch_infer_state.start_loc = seq_start_indexes.to('cuda')
@@ -251,4 +246,4 @@ def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
# => put information already recorded in batchinferstate and pass it to model forward
# => clear records in engine
def add_request():
raise NotImplementedError()
raise NotImplementedError()
3 changes: 2 additions & 1 deletion colossalai/inference/tensor_parallel/modeling/__init__.py
@@ -1,3 +1,4 @@
from .bloom import BloomInferenceForwards
from .llama import LlamaInferenceForwards

__all__ = ['LlamaInferenceForwards']
__all__ = ['BloomInferenceForwards', 'LlamaInferenceForwards']
541 changes: 541 additions & 0 deletions colossalai/inference/tensor_parallel/modeling/bloom.py

Large diffs are not rendered by default.
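Since this 541-line file is collapsed in the diff view, the outline below is only a hypothetical sketch of the interface it appears to expose, reconstructed from the method names that BloomModelInferPolicy (shown further down) registers as replacements on the HuggingFace Bloom modules; the actual signatures and bodies in the PR will differ.

# Hypothetical outline only -- not the actual contents of modeling/bloom.py.
# Method names are taken from the replacements registered in BloomModelInferPolicy.
class BloomInferenceForwards:
    """Replacement forwards that presumably route Bloom inference through the
    TP engine's BatchInferState / MemoryManager KV cache rather than the
    default past_key_values path."""

    @staticmethod
    def bloom_for_causal_lm_forward(self, *args, **kwargs):
        ...    # replaces BloomForCausalLM.forward

    @staticmethod
    def bloom_for_causal_lm_prepare_inputs_for_generation(self, *args, **kwargs):
        ...    # replaces BloomForCausalLM.prepare_inputs_for_generation

    @staticmethod
    def bloom_model_forward(self, *args, **kwargs):
        ...    # replaces BloomModel.forward

    @staticmethod
    def bloom_block_forward(self, *args, **kwargs):
        ...    # replaces BloomBlock.forward

    @staticmethod
    def bloom_attention_forward(self, *args, **kwargs):
        ...    # replaces BloomAttention.forward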

4 changes: 4 additions & 0 deletions colossalai/inference/tensor_parallel/policies/__init__.py
@@ -0,0 +1,4 @@
from .bloom import BloomModelInferPolicy
from .llama import LlamaModelInferPolicy

__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy']
44 changes: 44 additions & 0 deletions colossalai/inference/tensor_parallel/policies/bloom.py
@@ -0,0 +1,44 @@
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy

from ..modeling.bloom import BloomInferenceForwards


class BloomModelInferPolicy(BloomForCausalLMPolicy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
policy = super().module_policy()
# NOTE set inference mode to shard config
self.shard_config._infer()

if self.shard_config.enable_tensor_parallelism:

method_replacement = {
'forward':
BloomInferenceForwards.bloom_for_causal_lm_forward,
'prepare_inputs_for_generation':
BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomForCausalLM)

method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomModel)

method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomBlock)

method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=BloomAttention)

return policy
colossalai/inference/tensor_parallel/policies/llama.py (renamed from colossalai/inference/tensor_parallel/pollcies/llama.py)
@@ -2,7 +2,8 @@

from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

from ..modeling.llama import LlamaInferenceForwards
from ..modeling.llama import LlamaInferenceForwards


class LlamaModelInferPolicy(LlamaForCausalLMPolicy):

@@ -23,13 +24,17 @@ def module_policy(self):
infer_forward = LlamaInferenceForwards.llama_model_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer)

self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaDecoderLayer)

infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
method_replacement = {'forward': partial(infer_forward)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention)
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaAttention)

return policy
return policy
3 changes: 0 additions & 3 deletions colossalai/inference/tensor_parallel/pollcies/__init__.py

This file was deleted.

8 changes: 6 additions & 2 deletions colossalai/shardformer/policies/auto_policy.py
@@ -137,16 +137,20 @@ class PolicyLocation:
PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"),
"transformers.models.llama.modeling_llama.LlamaForCausalLM":
PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"),
# Bloom
"transformers.models.bloom.modeling_bloom.BloomModel":
PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
"transformers.models.bloom.modeling_bloom.BloomForCausalLM":
PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"),
}


def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
"""
Dynamically import a Policy class based on the policy location.
"""

if inference_only:
module_name = f"colossalai.inference.tensor_parallel.pollcies.{policy_location.file_name}"
module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
else:
module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
module = importlib.import_module(module_name)
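With the pollcies -> policies package name fixed, the inference registry resolves cleanly. An illustrative sanity check (not part of the PR), using only the PolicyLocation and import_policy names shown in this file:

# Illustrative only -- not part of the PR.
from colossalai.shardformer.policies.auto_policy import PolicyLocation, import_policy

location = PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy")
# With inference_only=True this imports
# colossalai.inference.tensor_parallel.policies.bloom (the corrected path)
# and should hand back the BloomModelInferPolicy registered above.
policy = import_policy(location, inference_only=True)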
60 changes: 60 additions & 0 deletions tests/test_infer/test_bloom_infer.py
@@ -0,0 +1,60 @@
import pytest
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM

import colossalai
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
MAX_OUTPUT_LEN = 32


def run():

model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

text = "Introduce some landmarks in Beijing"
input_ids = tokenizer.batch_encode_plus([text], return_tensors='pt')

model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
model = model.half()
model.to(torch.cuda.current_device())

shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
shardformer = ShardFormer(shard_config=shard_config)

infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
infer_engine.prepare_with_shard_config(shard_config=shard_config)
infer_engine.shard_model_by(shardformer)

generate_kwargs = dict(do_sample=False)
outputs = infer_engine.generate(input_ids, generate_kwargs)

if not dist.is_initialized() or dist.get_rank() == 0:
output_text = tokenizer.decode(outputs[0])
print(output_text)


def check_engine(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_engine_infer():
spawn(check_engine, TP_SIZE)


if __name__ == '__main__':
test_engine_infer()
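The checkpoint path hard-coded in run() is site-specific. Because the auto policy keys on the transformers Bloom classes, any Bloom-family checkpoint should work in its place; a possible (untested) substitution inside run():

# Example substitution only; "bigscience/bloom-560m" is a public Bloom checkpoint
# and not the model this PR was validated with.
model_path = "bigscience/bloom-560m"
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token
model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
# ...the remainder of run() is unchanged.

Running the file directly (python tests/test_infer/test_bloom_infer.py) goes through the __main__ hook, which spawns TP_SIZE = 2 worker processes via colossalai's spawn helper, so two GPUs are required.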