11 changes: 6 additions & 5 deletions colossalai/inference/tensor_parallel/engine.py
@@ -77,14 +77,15 @@ def __init__(
)
self.layer_num = num_hidden_layers

self.multi_query_group_num = 0
self.multi_query_group_num = model.config.num_attention_heads
# default to attention_heads
self.multi_query_attention = model.config.multi_query_attention

if hasattr(model.config, "multi_query_group_num"):
self.multi_query_group_num = model.config.multi_query_group_num

if hasattr(model.config, "num_key_value_heads"):
self.multi_query_group_num = model.config.num_key_value_heads

self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None

@@ -107,7 +108,7 @@ def _init_manager(self) -> None:
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size # update sharded number of heads

if self.multi_query_group_num:
if self.multi_query_attention:
# NOTE the logic of MQA tensor parallelism should be specified.
assert (
self.multi_query_group_num % self.tp_size == 0
@@ -218,7 +219,7 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."

model = model.model if self.shard_config.inference_gptq else model
policy = get_autopolicy(model, shard_config=self.shard_config)

@@ -311,7 +312,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch

block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to("cuda")
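For context on the hunks above, here is a minimal standalone sketch of the key/value-head resolution order the engine now appears to use: default to num_attention_heads (ordinary multi-head attention), then override when the config exposes multi_query_group_num (ChatGLM2-style) or num_key_value_heads (Llama-2-style grouped-query attention). Only the attribute names visible in the diff are taken from the source; the helper itself is illustrative.

def resolve_kv_head_count(config) -> int:
    # Default: one KV head per attention head (plain multi-head attention).
    kv_heads = config.num_attention_heads
    # ChatGLM2-style configs expose the KV group count directly.
    if hasattr(config, "multi_query_group_num"):
        kv_heads = config.multi_query_group_num
    # Llama-2-style configs report grouped-query KV heads as num_key_value_heads.
    if hasattr(config, "num_key_value_heads"):
        kv_heads = config.num_key_value_heads
    return kv_heads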
12 changes: 5 additions & 7 deletions colossalai/inference/tensor_parallel/modeling/chatglm2.py
@@ -395,9 +395,9 @@ def chatglm_flash_attn_kvcache_forward(
assert use_cache is True, "use_cache should be set to True using this chatglm attention"
# hidden_states: original :[sq, b, h] --> this [b, sq, h]
batch_size = hidden_states.shape[0]
hidden_size = hidden_states.shape[-1]
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer = self.query_key_value(hidden_states)

if self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
@@ -437,7 +437,6 @@ def chatglm_flash_attn_kvcache_forward(
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

cos, sin = infer_state.position_cos, infer_state.position_sin

chatglm2_rotary_emb_fwd(
@@ -466,19 +465,18 @@ def chatglm_flash_attn_kvcache_forward(
value_layer = value_layer.reshape(
-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
)

if infer_state.is_context_stage:
# first token generation:
# copy key and value calculated in current step to memory manager

copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_layer,
value_layer,
infer_state.context_mem_index,
infer_state.cache_manager,
)

attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))

# NOTE: no bug in context attn fwd (del it )
lightllm_llama2_context_attention_fwd(
@@ -514,7 +512,7 @@ def chatglm_flash_attn_kvcache_forward(
)

# second token and follows
attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
: infer_state.decode_mem_end, :, :
]
@@ -542,6 +540,6 @@ def chatglm_flash_attn_kvcache_forward(
# =================
# Output:[b,sq, h]
# =================
output = self.dense(attn_output).reshape(batch_size, -1, hidden_size)

output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size)
return output, kv_cache
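As a side note on the two .contiguous() calls added above: Tensor.view requires a compatible (typically contiguous) memory layout and raises a RuntimeError on tensors produced by transposes or similar stride-changing operations, whereas .contiguous() first copies the data into a compact layout. A small self-contained demonstration with arbitrary shapes (not the real attention shapes):

import torch

q = torch.randn(4, 8, 16).transpose(0, 1)  # stride-permuted -> non-contiguous

try:
    q.view(-1, 4 * 16)              # fails: view() cannot merge these strided dims
except RuntimeError as err:
    print("view failed:", err)

flat = q.contiguous().view(-1, 4 * 16)  # works: compact copy first, then view
print(flat.shape)                   # torch.Size([8, 64])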
5 changes: 4 additions & 1 deletion colossalai/inference/tensor_parallel/policies/chatglm2.py
@@ -48,7 +48,10 @@ def module_policy(self):
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=SelfAttention
)

if self.shard_config.enable_tensor_parallelism:
policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = (
self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
)
# for rmsnorm and others, we need to check the shape
return policy

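To make the new attribute replacement concrete, a worked example of the per-partition arithmetic it performs (the numbers are illustrative; ChatGLM2-6B ships with two KV groups, but any pair that divides evenly works):

multi_query_group_num = 2   # total KV groups in the unsharded model (illustrative)
tensor_parallel_size = 2    # number of tensor-parallel ranks (illustrative)

# Mirrors the assertion added in engine.py: groups must shard evenly across ranks.
assert multi_query_group_num % tensor_parallel_size == 0

# Mirrors the attribute replacement above: each rank keeps its share of KV groups.
num_multi_query_groups_per_partition = multi_query_group_num // tensor_parallel_size
print(num_multi_query_groups_per_partition)  # -> 1 group per rank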
1 change: 0 additions & 1 deletion colossalai/shardformer/layer/linear.py
@@ -149,7 +149,6 @@ def from_native_module(
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device

# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
@@ -400,7 +400,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
)

self.core_attention = CoreAttention(config, self.layer_number)

# Output.
self.dense = nn.Linear(
self.projection_size,
1 change: 0 additions & 1 deletion colossalai/shardformer/policies/chatglm2.py
@@ -104,7 +104,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
),
],
)

# optimization configuration
self.append_or_create_submodule_replacement(
description=[
1 change: 0 additions & 1 deletion colossalai/shardformer/shard/sharder.py
@@ -180,7 +180,6 @@ def _replace_sub_module(
assert target_module is not None, "target_module should not be None"

native_sub_module = getattr_(org_layer, suffix, ignore=True)

# Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
continue
12 changes: 8 additions & 4 deletions tests/test_infer/test_chatglm2_infer.py
@@ -13,13 +13,14 @@
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

try:
import lightllm
import lightllm # noqa

HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False

os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
TPSIZE = 1
TPSIZE = 2
BATCH_SIZE = 8
MAX_INPUT_LEN = 12
MAX_OUTPUT_LEN = 100
@@ -67,7 +68,10 @@ def check_chatglm2(rank, world_size, port):
run_chatglm2_test()


@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
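The test changes above follow a common optional-dependency pattern: probe the import inside try/except (with # noqa so linters do not flag the apparently unused import) and gate the test with pytest.mark.skipif. A stripped-down sketch of that pattern, not the exact test from the repo:

import pytest

try:
    import lightllm  # noqa: F401  (import only probes availability)
    HAS_LIGHTLLM_KERNEL = True
except ImportError:
    HAS_LIGHTLLM_KERNEL = False


@pytest.mark.skipif(not HAS_LIGHTLLM_KERNEL, reason="lightllm kernels are not installed")
def test_needs_lightllm():
    assert HAS_LIGHTLLM_KERNEL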