Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions colossalai/inference/tensor_parallel/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2

_supported_models = ["LlamaForCausalLM", "LlamaModel", "BloomForCausalLM"]
# Names of the model classes listed as supported: the Llama and Bloom entries
# plus the two ChatGLM classes.
# NOTE(review): where this list is consulted is outside this view — presumably
# it is compared against the wrapped model's class name; confirm at the call site.
_supported_models = [
"LlamaForCausalLM",
"LlamaModel",
"BloomForCausalLM",
"ChatGLMModel",
"ChatGLMForConditionalGeneration",
]


class TPInferEngine:
Expand Down Expand Up @@ -63,7 +69,13 @@ def __init__(

self.head_dim = model.config.hidden_size // model.config.num_attention_heads
self.head_num = model.config.num_attention_heads
self.layer_num = model.config.num_hidden_layers
num_hidden_layers = (
model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
)
self.layer_num = num_hidden_layers
self.multi_query_group_num = (
model.config.multi_query_group_num if hasattr(model.config, "multi_query_group_num") else 0
)

self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None
Expand All @@ -77,9 +89,22 @@ def _init_manager(self) -> None:
assert self.tp_size >= 1, "TP size not initialized without providing a valid ShardConfig"
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size # update sharded number of heads
self.cache_manager = MemoryManager(
self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
)
if self.multi_query_group_num:
Comment thread
CjhHa1 marked this conversation as resolved.
# NOTE the logic of MQA tensor parallelism should be specified.
assert (
self.multi_query_group_num % self.tp_size == 0
), f"Cannot shard {self.multi_query_group_num} query groups with tp size {self.tp_size}"
self.cache_manager = MemoryManager(
self.max_total_token_num,
self.dtype,
self.multi_query_group_num // self.tp_size,
self.head_dim,
self.layer_num,
)
else:
self.cache_manager = MemoryManager(
self.max_total_token_num, self.dtype, self.head_num, self.head_dim, self.layer_num
)

def _optimize_model(self, model: nn.Module) -> None:
"""
Expand Down
5 changes: 4 additions & 1 deletion colossalai/inference/tensor_parallel/modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Explicit relative import: `_utils` is a sibling module of this package.
# A bare `import _utils` is an absolute import in Python 3 and only resolves
# if this directory happens to be on sys.path.
from . import _utils

from .bloom import BloomInferenceForwards
from .chatglm2 import ChatGLM2InferenceForwards
from .llama import LlamaInferenceForwards

__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards"]
__all__ = ["BloomInferenceForwards", "LlamaInferenceForwards", "ChatGLM2InferenceForwards"]
10 changes: 10 additions & 0 deletions colossalai/inference/tensor_parallel/modeling/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
Utils for model inference
"""
from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest


def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
    """Copy one layer's key and value buffers into the memory manager's cache.

    Both copies are delegated to ``copy_kv_cache_to_dest``; keys are copied
    first, then values.

    Args:
        layer_id: Index selecting the destination entry in
            ``mem_manager.key_buffer`` and ``mem_manager.value_buffer``.
        key_buffer: Source handed to the copy kernel for the keys.
        value_buffer: Source handed to the copy kernel for the values.
        context_mem_index: Forwarded unchanged as the second argument of both
            copy calls (presumably the destination slot indices -- confirm
            against ``copy_kv_cache_to_dest``).
        mem_manager: Object exposing ``key_buffer`` and ``value_buffer``,
            each indexable by ``layer_id``.

    Returns:
        None.
    """
    key_cache = mem_manager.key_buffer[layer_id]
    copy_kv_cache_to_dest(key_buffer, context_mem_index, key_cache)

    value_cache = mem_manager.value_buffer[layer_id]
    copy_kv_cache_to_dest(value_buffer, context_mem_index, value_cache)
Loading