From 942609a2bda0c24c323f4ad94e9367079a1c7117 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 6 Nov 2023 17:00:49 +0800 Subject: [PATCH 1/4] fix bug --- .../inference/tensor_parallel/modeling/chatglm2.py | 10 ++++------ .../inference/tensor_parallel/policies/chatglm2.py | 5 ++++- colossalai/shardformer/layer/linear.py | 1 - .../modeling/chatglm2_6b/modeling_chatglm.py | 1 - colossalai/shardformer/policies/chatglm2.py | 1 - colossalai/shardformer/shard/sharder.py | 1 - tests/test_infer/test_chatglm2_infer.py | 10 ++++++---- 7 files changed, 14 insertions(+), 15 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 69a92c4fe746..cce470a09280 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -395,9 +395,9 @@ def chatglm_flash_attn_kvcache_forward( assert use_cache is True, "use_cache should be set to True using this chatglm attention" # hidden_states: original :[sq, b, h] --> this [b, sq, h] batch_size = hidden_states.shape[0] + hidden_size = hidden_states.shape[-1] # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer = self.query_key_value(hidden_states) - if self.multi_query_attention: (query_layer, key_layer, value_layer) = mixed_x_layer.split( [ @@ -437,7 +437,6 @@ def chatglm_flash_attn_kvcache_forward( mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - cos, sin = infer_state.position_cos, infer_state.position_sin chatglm2_rotary_emb_fwd( @@ -466,10 +465,10 @@ def chatglm_flash_attn_kvcache_forward( value_layer = value_layer.reshape( -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head ) + if infer_state.is_context_stage: # first token generation: # copy key and value calculated in current step to memory manager - copy_kv_to_mem_cache( infer_state.decode_layer_id, key_layer, @@ -477,8 +476,7 @@ def chatglm_flash_attn_kvcache_forward( infer_state.context_mem_index, infer_state.cache_manager, ) - - attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size)) # NOTE: no bug in context attn fwd (del it ) lightllm_llama2_context_attention_fwd( @@ -542,6 +540,6 @@ def chatglm_flash_attn_kvcache_forward( # ================= # Output:[b,sq, h] # ================= + output = self.dense(attn_output).reshape(batch_size, -1, hidden_size) - output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size) return output, kv_cache diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/inference/tensor_parallel/policies/chatglm2.py index 90f8b4fd2d7e..60dc511f5e96 100644 --- a/colossalai/inference/tensor_parallel/policies/chatglm2.py +++ b/colossalai/inference/tensor_parallel/policies/chatglm2.py @@ -48,7 +48,10 @@ def module_policy(self): self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=SelfAttention ) - + if self.shard_config.enable_tensor_parallelism: + policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = ( + self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size + ) # for rmsnorm and others, we need to check the shape return policy diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index cf2003877d3c..9e638622348e 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -149,7 +149,6 @@ def from_native_module( out_features = module.out_features bias = module.bias is not None device = module.weight.device - # ensure only one process group is passed if isinstance(process_group, (list, tuple)): assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py index fdd49ecfeae5..71aa2296eb4c 100644 --- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py +++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py @@ -400,7 +400,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None): ) self.core_attention = CoreAttention(config, self.layer_number) - # Output. self.dense = nn.Linear( self.projection_size, diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index ab18d80b76e5..d1ad9f91478b 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -104,7 +104,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ), ], ) - # optimization configuration self.append_or_create_submodule_replacement( description=[ diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index e3c0aa93d466..0586ada9eedd 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -180,7 +180,6 @@ def _replace_sub_module( assert target_module is not None, "target_module should not be None" native_sub_module = getattr_(org_layer, suffix, ignore=True) - # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled. if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include): continue diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 09bb8a94994d..a07835c58676 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -13,13 +13,12 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn try: - import lightllm HAS_LIGHTLLM_KERNEL = True except: HAS_LIGHTLLM_KERNEL = False - + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -TPSIZE = 1 +TPSIZE = 2 BATCH_SIZE = 8 MAX_INPUT_LEN = 12 MAX_OUTPUT_LEN = 100 @@ -67,7 +66,10 @@ def check_chatglm2(rank, world_size, port): run_chatglm2_test() -@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, + reason="kv-cache manager engine requires cuda version to be higher than 11.5", +) @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() From 052dd8dcfb18a6eafba5b78858c6ce88a019b726 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 6 Nov 2023 17:21:28 +0800 Subject: [PATCH 2/4] fix --- tests/test_infer/test_chatglm2_infer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index a07835c58676..1bce510eedc0 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -13,6 +13,7 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn try: + import lightllm HAS_LIGHTLLM_KERNEL = True except: HAS_LIGHTLLM_KERNEL = False From 23f33b8bc1cff2e638b4376c4013a72a6a760854 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 6 Nov 2023 18:19:39 +0800 Subject: [PATCH 3/4] fix multiquery --- colossalai/inference/tensor_parallel/engine.py | 11 ++++++----- .../inference/tensor_parallel/modeling/chatglm2.py | 2 +- tests/test_infer/test_chatglm2_infer.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 283f719e57fc..2eadbcab1bda 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -77,14 +77,15 @@ def __init__( ) self.layer_num = num_hidden_layers - self.multi_query_group_num = 0 + self.multi_query_group_num = model.config.num_attention_heads + # default to attention_heads + self.multi_query_attention = model.config.multi_query_attention if hasattr(model.config, "multi_query_group_num"): self.multi_query_group_num = model.config.multi_query_group_num if hasattr(model.config, "num_key_value_heads"): self.multi_query_group_num = model.config.num_key_value_heads - self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config self.cache_manager = None @@ -107,7 +108,7 @@ def _init_manager(self) -> None: assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}" self.head_num //= self.tp_size # update sharded number of heads - if self.multi_query_group_num: + if self.multi_query_attention: # NOTE the logic of MQA tensor parallelism should be specified. assert ( self.multi_query_group_num % self.tp_size == 0 @@ -218,7 +219,7 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." - + model = model.model if self.shard_config.inference_gptq else model policy = get_autopolicy(model, shard_config=self.shard_config) @@ -311,7 +312,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState: seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state.seq_len = seq_lengths.to("cuda") diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index cce470a09280..b8fe8eb54855 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -512,7 +512,7 @@ def chatglm_flash_attn_kvcache_forward( ) # second token and follows - attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) + attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size)) cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ : infer_state.decode_mem_end, :, : ] diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 1bce510eedc0..df8a35852ca6 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -13,7 +13,7 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn try: - import lightllm + # noqa HAS_LIGHTLLM_KERNEL = True except: HAS_LIGHTLLM_KERNEL = False From 822bb303f3a984895add069d119643d79c205c0e Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Mon, 6 Nov 2023 18:21:59 +0800 Subject: [PATCH 4/4] fix multiquery --- tests/test_infer/test_chatglm2_infer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index df8a35852ca6..a2ec35dcdb8a 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -13,7 +13,8 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn try: - # noqa + import lightllm # noqa + HAS_LIGHTLLM_KERNEL = True except: HAS_LIGHTLLM_KERNEL = False