440 changes: 440 additions & 0 deletions colossalai/inference/continous_batching/layers/attention.py

Large diffs are not rendered by default.

200 changes: 167 additions & 33 deletions colossalai/inference/tensor_parallel/engine.py
@@ -2,31 +2,61 @@

import torch
import torch.nn as nn
from transformers import BloomForCausalLM, LlamaForCausalLM
import warnings
from transformers import (
AutoConfig,
AutoTokenizer,
BloomForCausalLM,
LlamaForCausalLM,
LlamaTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.generation import GenerationConfig
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.tokenization_utils_base import BatchEncoding
try:
from vllm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

USE_CONTINOUS_BATCHING = True

except ImportError:
warnings.warn("vllm is not installed, continuous batching will not be supported.")
USE_CONTINOUS_BATCHING = False
RequestOutput = None

from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.auto_policy import get_autopolicy

from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager
from .utils import init_to_get_rotary, replace_page_attention

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2

_supported_models = ['LlamaForCausalLM', 'LlamaModel', 'BloomForCausalLM']
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM,
"LlamaModel": LlamaForCausalLM,
"BloomForCausalLM": BloomForCausalLM
}


class TPInferEngine:
"""Engine class for tensor parallel inference.

Args:
model (Module): original model, e.g. huggingface CausalLM
model (str, nn.Module): The name, path or instance of a HuggingFace Transformers model.
shard_config (ShardConfig): The config for sharding original model
max_batch_size (int): maximum batch size
max_input_len (int): maximum input length of sequence
max_output_len (int): maximum output length of output tokens
trust_remote_code (bool): Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tokenizer (str, PreTrainedTokenizer, PreTrainedTokenizerFast): The name, path or instance of a HuggingFace Transformers tokenizer.
use_continous_batching (bool): whether to use continuous batching (silently disabled when vllm is not installed)
dtype (torch.dtype): datatype used to init KV cache space
device (str): device on which the engine's KV cache is initialized

@@ -40,13 +70,48 @@ class TPInferEngine:
"""

def __init__(self,
model: nn.Module,
shard_config: ShardConfig,
max_batch_size: int,
max_input_len: int,
max_output_len: int,
model: Union[str, nn.Module],
shard_config: ShardConfig = None,
max_batch_size: int = 8,
max_input_len: int = 16,
max_output_len: int = 8,
trust_remote_code: bool = False,
use_continous_batching: bool = False,
tokenizer: Optional[Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
dtype: torch.dtype = torch.float16,
device: str = 'cuda') -> None:

if tokenizer is None:
assert isinstance(model, str), \
"when tokenizer is None, model must be a string (model name or path)."
tokenizer = model

self.tp_size = 1
self.use_continous_batching = use_continous_batching and USE_CONTINOUS_BATCHING

if shard_config is not None and shard_config.enable_tensor_parallelism:
self.tp_size = shard_config.tensor_parallel_size

if self.use_continous_batching:
assert isinstance(model, str) and isinstance(tokenizer, str), \
"when using continuous batching, model and tokenizer must be strings."
self.llm_engine = LLM(model=model,
tokenizer=tokenizer,
trust_remote_code=trust_remote_code,
tensor_parallel_size=self.tp_size)
# TODO: replace multiple models' attention forward with shardformer in vllm to achieve multi-stream optimization later.
# kv_cache_stream = torch.cuda.Stream()
# self.model = replace_page_attention(self.llm_engine.llm_engine.workers[0].model, kv_cache_stream)
self.model = self.llm_engine.llm_engine.workers[0].model
else:
self.model, self.tokenizer = self._get_model_and_tokenizer(model, tokenizer, trust_remote_code)
self.tokenizer.pad_token = self.tokenizer.eos_token

self.model = self.model.half()
self.model = self.model.to(device)
self.shard_config = shard_config

self.max_batch_size = max_batch_size
self.max_input_len = max_input_len
self.max_output_len = max_output_len
@@ -59,17 +124,47 @@ def __init__(self,

self.dtype = dtype

self.head_dim = model.config.hidden_size // model.config.num_attention_heads
self.head_num = model.config.num_attention_heads
self.layer_num = model.config.num_hidden_layers
self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads
self.head_num = self.model.config.num_attention_heads
self.layer_num = self.model.config.num_hidden_layers

self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None

self.shard_config = shard_config
self.model = None
# optimize the original model by sharding with ShardFormer
self._optimize_model(model=model.to(device))

self._optimize_model()

def _get_model_and_tokenizer(self, model: Union[str, nn.Module], tokenizer: Union[str, PreTrainedTokenizer, PreTrainedTokenizerFast], trust_remote_code: bool):

if isinstance(model, nn.Module) and isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
return model, tokenizer

supported_model = model

try:
config = AutoConfig.from_pretrained(model, trust_remote_code=trust_remote_code)
except ValueError as e:
if (not trust_remote_code and "requires you to execute the configuration file" in str(e)):
err_msg = ("Failed to load the model config. If the model is a custom "
"model not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _supported_models:
if isinstance(tokenizer, str):
if arch in ("LlamaForCausalLM", "LLaMAForCausalLM", "LlamaModel"):
tokenizer = LlamaTokenizer.from_pretrained(model, trust_remote_code=trust_remote_code)
else:
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=trust_remote_code)
if isinstance(model, str):
supported_model = _supported_models[arch].from_pretrained(model,
pad_token_id=tokenizer.eos_token_id)

return supported_model, tokenizer
raise ValueError(f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {self.supported_models}")

def _init_manager(self) -> None:
assert self.tp_size >= 1, "TP size is not initialized; provide a valid ShardConfig"
@@ -78,16 +173,17 @@ def _init_manager(self) -> None:
self.cache_manager = MemoryManager(self.max_total_token_num, self.dtype, self.head_num, self.head_dim,
self.layer_num)

def _optimize_model(self, model: nn.Module) -> None:
def _optimize_model(self) -> None:
"""
Optimize the original model by sharding with ShardFormer.
In further generation, use the sharded model instead of original model.
"""
# NOTE we will change to use an inference config later with additional attrs we want
assert self.shard_config.inference_only is True
shardformer = ShardFormer(shard_config=self.shard_config)
self._prepare_with_shard_config(shard_config=self.shard_config)
self._shard_model_by(shardformer, model)
if not self.use_continous_batching:
assert self.shard_config.inference_only is True
shardformer = ShardFormer(shard_config=self.shard_config)
self._prepare_with_shard_config(shard_config=self.shard_config)
self._shard_model_by(shardformer)

def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None) -> ShardConfig:
""" Prepare the engine with a given ShardConfig.
@@ -96,7 +192,6 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)
shard_config (ShardConfig): shard config given to specify settings of the engine.
If not provided, a default ShardConfig with tp size 1 will be created.
"""
self.tp_size = 1
if shard_config is None:
shard_config = ShardConfig(
tensor_parallel_process_group=None,
@@ -111,37 +206,76 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)
else:
shard_config.inference_only = True
shard_config.pipeline_stage_manager = None
if shard_config.enable_tensor_parallelism:
self.tp_size = shard_config.tensor_parallel_size
self._init_manager()

return shard_config

def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
def _shard_model_by(self, shardformer: ShardFormer) -> None:
""" Shard original model by the given ShardFormer and store the sharded model. """
assert self.tp_size == shardformer.shard_config.tensor_parallel_size, \
"Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
model_name = self.model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
policy = get_autopolicy(model, inference_only=True)
self.model, _ = shardformer.optimize(model, policy)
if model_name == "LlamaForCausalLM":
init_to_get_rotary(self.model.model, base=10000)
policy = get_autopolicy(self.model, inference_only=True)
self.model, _ = shardformer.optimize(self.model, policy)
self.model = self.model.cuda()

@property
def supported_models(self) -> List[str]:
return _supported_models
return list(_supported_models.keys())

def generate(self, input_tokens: Union[BatchEncoding, dict, list, torch.Tensor], **generate_kwargs) -> torch.Tensor:
def generate(self,
prompts: Optional[Union[str, List[str]]] = None,
prompt_token_ids: Optional[Union[BatchEncoding, dict, list, torch.Tensor]] = None,
**generate_kwargs) -> Union[List[RequestOutput], torch.Tensor]:
"""Generate token sequence.

Args:
input_tokens: could be one of the following types
prompts: A list of prompts to generate completions for.
prompt_token_ids: could be one of the following types
1. BatchEncoding or dict (e.g. tokenizer batch_encode)
2. list of input token ids (e.g. appended result of tokenizer encode)
3. torch.Tensor (e.g. tokenizer encode with return_tensors='pt')
If None, we use the tokenizer to convert the prompts to token IDs.
Returns:
torch.Tensor: The returned sequence is given inputs + generated_tokens.
Union[List[RequestOutput], torch.Tensor]: vllm RequestOutput objects when continuous batching is used; otherwise a tensor containing the given inputs followed by the generated tokens.
"""

if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if prompts is not None and prompt_token_ids is not None:
if isinstance(prompt_token_ids, (BatchEncoding, dict)):
prompt_token_len = len(prompt_token_ids['input_ids'])
elif isinstance(prompt_token_ids, torch.Tensor):
prompt_token_len = prompt_token_ids.shape[0]
else:
prompt_token_len = len(prompt_token_ids)
if len(prompts) != prompt_token_len:
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")

if self.use_continous_batching:
if not isinstance(prompt_token_ids, list):
raise TypeError("prompt_token_ids must be a list when using continuous batching.")
sampling_params = SamplingParams(temperature=0.0, max_tokens=self.max_output_len)
return self.llm_engine.generate(prompts=prompts,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)

if prompt_token_ids is None:
input_tokens = self.tokenizer.batch_encode_plus(prompts, return_tensors="pt", padding=True)
else:
input_tokens = prompt_token_ids

if isinstance(input_tokens, list):
input_tokens = torch.tensor(input_tokens, dtype=torch.long).cuda()  # token ids must be integral

if isinstance(input_tokens, torch.Tensor):
input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool))
for t in input_tokens:
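
To tie the pieces of this engine diff together, here is a minimal usage sketch of the refactored `TPInferEngine`. The checkpoint path and the ShardConfig flags are illustrative assumptions, not values taken from this PR:

```python
import torch
from colossalai.shardformer import ShardConfig
from colossalai.inference.tensor_parallel.engine import TPInferEngine

# Illustrative setup: single GPU, no tensor parallelism.
shard_config = ShardConfig(enable_tensor_parallelism=False)
shard_config.inference_only = True  # asserted by _optimize_model in the non-vllm path

engine = TPInferEngine(
    model="/path/to/llama-7b",     # name or path; also reused as the tokenizer source
    shard_config=shard_config,
    max_batch_size=4,
    max_input_len=128,
    max_output_len=64,
    use_continous_batching=False,  # silently disabled anyway if vllm is missing
)

# Either raw prompts (tokenized internally) ...
out = engine.generate(prompts=["Hello, my name is"])
# ... or pre-tokenized ids, e.g. from tokenizer(prompts, return_tensors="pt").
```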
52 changes: 52 additions & 0 deletions colossalai/inference/tensor_parallel/utils.py
@@ -0,0 +1,52 @@
import types
import warnings

import torch
try:
from vllm.model_executor.models.llama import LlamaAttention
VLLM_INSTALLED = True
except ImportError:
warnings.warn("vllm is not installed, PageAttention will not be replaced.")
VLLM_INSTALLED = False

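# Called as init_to_get_rotary(model, base): the model module is passed as `self`,
# and RoPE cos/sin lookup tables are cached on it for later attention kernels.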
def init_to_get_rotary(self, base=10000):
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
if not hasattr(self.config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
if hasattr(self.config, "max_sequence_length"):
max_seq_len = self.config.max_sequence_length
elif hasattr(self.config, "max_position_embeddings"):
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
base = float(base)
inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) /
self.config.head_dim_))
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
return
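
For context on what consumes these tables: below is a hedged sketch of the standard interleaved RoPE application that caches like `_cos_cached`/`_sin_cached` feed into. The pairing convention (interleaved vs. split halves) varies between implementations, so treat this as illustrative rather than the exact kernel this PR uses:

```python
import torch

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (seq_len, num_heads, head_dim); cos/sin: (seq_len, head_dim // 2)
    x0, x1 = x[..., 0::2], x[..., 1::2]   # interleaved channel pairs
    cos = cos[:, None, :]                 # broadcast over the head axis
    sin = sin[:, None, :]
    out = torch.empty_like(x)
    out[..., 0::2] = x0 * cos - x1 * sin  # rotate each pair by the position angle
    out[..., 1::2] = x0 * sin + x1 * cos
    return out
```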


def replace_page_attention(model, kv_cache_stream):
if VLLM_INSTALLED:

from colossalai.inference.continous_batching.layers.attention import PagedAttentionWithRoPE

layers = model.model.layers
for i in range(len(layers)):
layer = layers[i]
if isinstance(layer.self_attn, LlamaAttention):
attn = PagedAttentionWithRoPE(layer.self_attn.num_heads,
layer.self_attn.head_dim,
layer.self_attn.scaling,
rotary_dim=layer.self_attn.head_dim,
num_kv_heads=layer.self_attn.num_kv_heads,
kv_cache_stream=kv_cache_stream)
setattr(layer.self_attn, 'attn', attn)

return model
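
And a hedged sketch of how the helper above would be wired up, mirroring the commented-out call in `TPInferEngine.__init__` (the `workers[0].model` attribute layout is vllm-version dependent and assumed here):

```python
import torch
from colossalai.inference.tensor_parallel.utils import replace_page_attention

def patch_vllm_llama(llm) -> None:
    # `llm` is assumed to be a vllm.LLM whose first worker holds a Llama-style
    # model with `model.layers`, as the engine's commented-out code implies.
    kv_cache_stream = torch.cuda.Stream()  # side stream for KV-cache transfers
    replace_page_attention(llm.llm_engine.workers[0].model, kv_cache_stream)
```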