From 96db2bca6b53164f838b93160dd76add080612d3 Mon Sep 17 00:00:00 2001
From: GuangyaoZhang
Date: Fri, 7 Jun 2024 02:27:03 +0000
Subject: [PATCH 1/2] fix Llama rotary embedding api change for transformers 4.39.3

---
 colossalai/shardformer/modeling/llama.py       | 14 ++------------
 requirements/requirements.txt                  |  2 +-
 .../cuda/test_rotary_embdding_unpad.py         |  4 ++--
 .../triton/test_rotary_embdding_unpad.py       |  4 ++--
 tests/test_infer/test_models/test_attention.py |  2 +-
 5 files changed, 8 insertions(+), 18 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 01d10c8dcf95..8c32713171fc 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -512,18 +512,8 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index fa88501ef968..27bbc3769448 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -16,7 +16,7 @@ ray
 sentencepiece
 google
 protobuf
-transformers>=4.36.2,<4.40.0
+transformers==4.39.3
 peft>=0.7.1
 bitsandbytes>=0.39.0
 rpyc==6.0.0
diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
index 8237384c03fd..5ce4fb31ff84 100644
--- a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
@@ -31,10 +31,10 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
     x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
     emb = LlamaRotaryEmbedding(D)
-    cos, sin = emb(x0, TOTAL_TOKENS)
+    position_ids = torch.arange(TOTAL_TOKENS)
+    cos, sin = emb(x0, position_ids)
     cos_2 = cos[:, : D // 2]
     sin_2 = sin[:, : D // 2]
-    position_ids = torch.arange(TOTAL_TOKENS)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
     embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
     assert torch.allclose(embd_x0, embd_stimulated_x)
diff --git a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
index 570093693447..efbe3ede8ca8 100644
--- a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
@@ -46,10 +46,10 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
     x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
     x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
     emb = LlamaRotaryEmbedding(D)
-    cos, sin = emb(x0, TOTAL_TOKENS)
+    position_ids = torch.arange(TOTAL_TOKENS)
+    cos, sin = emb(x0, position_ids)
     cos_2 = cos[:, :32]
     sin_2 = sin[:, :32]
-    position_ids = torch.arange(TOTAL_TOKENS)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
     embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
     assert torch.allclose(embd_x0, embd_stimulated_x)
diff --git a/tests/test_infer/test_models/test_attention.py b/tests/test_infer/test_models/test_attention.py
index 79ed6675db5f..f66917d475d4 100644
--- a/tests/test_infer/test_models/test_attention.py
+++ b/tests/test_infer/test_models/test_attention.py
@@ -69,7 +69,7 @@ def test_context_attention():
     position_ids = torch.arange(0, 8, dtype=torch.long, device=proj_q.device)
     position_ids = position_ids.unsqueeze(0)
 
-    cos, sin = transformer_attn.rotary_emb(proj_v, 8)
+    cos, sin = transformer_attn.rotary_emb(proj_v, position_ids)
     proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids)
 
     pad_attn_output = attn.pad_context_forward(

From 13a18a4e3bd07f25ef5470dcd53c558f84a081c8 Mon Sep 17 00:00:00 2001
From: GuangyaoZhang
Date: Fri, 7 Jun 2024 02:27:03 +0000
Subject: [PATCH 2/2] fix Llama rotary embedding api change for transformers 4.39.3

---
 colossalai/shardformer/modeling/llama.py       | 16 +++-------------
 requirements/requirements.txt                  |  2 +-
 .../cuda/test_rotary_embdding_unpad.py         |  4 ++--
 .../triton/test_rotary_embdding_unpad.py       |  4 ++--
 tests/test_infer/test_models/test_attention.py |  2 +-
 5 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 01d10c8dcf95..21505549ceaa 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -511,19 +511,9 @@ def forward(
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index fa88501ef968..27bbc3769448 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -16,7 +16,7 @@ ray
 sentencepiece
 google
 protobuf
-transformers>=4.36.2,<4.40.0
+transformers==4.39.3
 peft>=0.7.1
 bitsandbytes>=0.39.0
 rpyc==6.0.0
diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
index 8237384c03fd..5ce4fb31ff84 100644
--- a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
@@ -31,10 +31,10 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
     x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
     emb = LlamaRotaryEmbedding(D)
-    cos, sin = emb(x0, TOTAL_TOKENS)
+    position_ids = torch.arange(TOTAL_TOKENS)
+    cos, sin = emb(x0, position_ids)
     cos_2 = cos[:, : D // 2]
     sin_2 = sin[:, : D // 2]
-    position_ids = torch.arange(TOTAL_TOKENS)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
     embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
     assert torch.allclose(embd_x0, embd_stimulated_x)
diff --git a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
index 570093693447..efbe3ede8ca8 100644
--- a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
@@ -46,10 +46,10 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
     x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
     x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
     emb = LlamaRotaryEmbedding(D)
-    cos, sin = emb(x0, TOTAL_TOKENS)
+    position_ids = torch.arange(TOTAL_TOKENS)
+    cos, sin = emb(x0, position_ids)
     cos_2 = cos[:, :32]
     sin_2 = sin[:, :32]
-    position_ids = torch.arange(TOTAL_TOKENS)
     embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
     embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
     assert torch.allclose(embd_x0, embd_stimulated_x)
diff --git a/tests/test_infer/test_models/test_attention.py b/tests/test_infer/test_models/test_attention.py
index 79ed6675db5f..f66917d475d4 100644
--- a/tests/test_infer/test_models/test_attention.py
+++ b/tests/test_infer/test_models/test_attention.py
@@ -69,7 +69,7 @@ def test_context_attention():
     position_ids = torch.arange(0, 8, dtype=torch.long, device=proj_q.device)
     position_ids = position_ids.unsqueeze(0)
 
-    cos, sin = transformer_attn.rotary_emb(proj_v, 8)
+    cos, sin = transformer_attn.rotary_emb(proj_v, position_ids)
    proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids)
 
     pad_attn_output = attn.pad_context_forward(
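
Note (not part of the patches above): a minimal sketch of the rotary-embedding call pattern this series migrates to under transformers==4.39.3. The tensor sizes and variable names are made up for illustration; only the LlamaRotaryEmbedding / apply_rotary_pos_emb signatures are what the patches rely on.

    import torch
    from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

    # Illustrative sizes only (not taken from the patch).
    bsz, num_heads, seq_len, head_dim = 2, 4, 8, 16
    query_states = torch.randn(bsz, num_heads, seq_len, head_dim)
    key_states = torch.randn(bsz, num_heads, seq_len, head_dim)
    value_states = torch.randn(bsz, num_heads, seq_len, head_dim)
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)

    rotary_emb = LlamaRotaryEmbedding(head_dim)

    # Old call pattern removed by these patches (pre-4.38 style):
    #   cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
    #   q, k = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    # New call pattern in 4.39.3: position_ids go into the rotary module,
    # and apply_rotary_pos_emb no longer needs them.
    cos, sin = rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    # cos/sin are now batched, shape [bsz, seq_len, head_dim];
    # query_states/key_states keep shape [bsz, num_heads, seq_len, head_dim].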