4 changes: 2 additions & 2 deletions colossalai/inference/__init__.py
@@ -1,4 +1,4 @@
 from .hybridengine import CaiInferEngine
-from .hybridengine.polices import LlamaModelInferPolicy
+from .hybridengine.polices import BloomModelInferPolicy, LlamaModelInferPolicy
 
-__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]
+__all__ = ["CaiInferEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy"]
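With this re-export, the Bloom policy is importable directly from colossalai.inference, the same as the existing Llama policy. A minimal usage sketch follows; the CaiInferEngine keyword arguments shown (model, model_policy, tp_size, pp_size, max_batch_size, max_input_len, max_output_len) are assumptions inferred from the fields this PR reads, not copied from the engine's actual signature.

# Hypothetical usage sketch, not taken from this PR: wiring a Bloom checkpoint
# into the hybrid engine via the newly exported BloomModelInferPolicy.
import torch
from transformers import BloomForCausalLM

from colossalai.inference import BloomModelInferPolicy, CaiInferEngine

model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16)

engine = CaiInferEngine(
    tp_size=1,                             # tensor-parallel degree (read as self.tp_size in _init_manager)
    pp_size=2,                             # pipeline-parallel degree (read as self.pp_size in _init_manager)
    model=model,
    model_policy=BloomModelInferPolicy(),  # assumed to be passed as an instance
    max_batch_size=4,
    max_input_len=128,
    max_output_len=64,
)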
21 changes: 15 additions & 6 deletions colossalai/inference/hybridengine/engine.py
@@ -16,6 +16,7 @@
 
 _supported_models = [
     "LlamaForCausalLM",
+    "BloomForCausalLM",
 ]


@@ -155,12 +156,20 @@ def _shardformer(self, model, model_policy, stage_manager, tp_group):
 
     def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None:
         max_total_token_num = max_batch_size * (max_input_len + max_output_len)
-        head_dim = model.config.hidden_size // model.config.num_attention_heads
-        head_num = model.config.num_attention_heads
-        num_hidden_layers = (
-            model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
-        )
-        layer_num = num_hidden_layers // self.pp_size
+        if model.config.model_type == "llama":
+            head_dim = model.config.hidden_size // model.config.num_attention_heads
+            head_num = model.config.num_attention_heads // self.tp_size
+            num_hidden_layers = (
+                model.config.num_hidden_layers
+                if hasattr(model.config, "num_hidden_layers")
+                else model.config.num_layers
+            )
+            layer_num = num_hidden_layers // self.pp_size
+        elif model.config.model_type == "bloom":
+            head_dim = model.config.hidden_size // model.config.n_head
+            head_num = model.config.n_head // self.tp_size
+            num_hidden_layers = model.config.n_layer
+            layer_num = num_hidden_layers // self.pp_size
 
         cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
         return cache_manager
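The new branch only changes where the cache geometry is read from: Bloom configs expose n_head and n_layer rather than num_attention_heads and num_hidden_layers. A short worked example of the MemoryManager sizing that _init_manager performs, using illustrative config values rather than any specific checkpoint:

# Worked example of the sizing math in _init_manager for a Bloom-style config.
# The concrete numbers are illustrative, not tied to a particular model.
hidden_size, n_head, n_layer = 4096, 32, 30   # hypothetical Bloom config values
tp_size, pp_size = 2, 2                        # parallel degrees held by the engine
max_batch_size, max_input_len, max_output_len = 4, 128, 64

max_total_token_num = max_batch_size * (max_input_len + max_output_len)  # 4 * 192 = 768
head_dim = hidden_size // n_head                                         # 4096 // 32 = 128
head_num = n_head // tp_size                                             # 32 // 2 = 16 heads per TP rank
layer_num = n_layer // pp_size                                           # 30 // 2 = 15 layers per PP stage

# cache_manager = MemoryManager(max_total_token_num, dtype, head_num, head_dim, layer_num)
print(max_total_token_num, head_dim, head_num, layer_num)  # 768 128 16 15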