Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions colossalai/shardformer/modeling/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,19 +511,8 @@ def forward(
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ ray
sentencepiece
google
protobuf
transformers>=4.36.2,<4.40.0
transformers==4.39.3
peft>=0.7.1
bitsandbytes>=0.39.0
rpyc==6.0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
position_ids = torch.arange(TOTAL_TOKENS)
cos, sin = emb(x0, position_ids)
cos_2 = cos[:, : D // 2]
sin_2 = sin[:, : D // 2]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
assert torch.allclose(embd_x0, embd_stimulated_x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
position_ids = torch.arange(TOTAL_TOKENS)
cos, sin = emb(x0, position_ids)
cos_2 = cos[:, :32]
sin_2 = sin[:, :32]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
assert torch.allclose(embd_x0, embd_stimulated_x)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_infer/test_models/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_context_attention():

position_ids = torch.arange(0, 8, dtype=torch.long, device=proj_q.device)
position_ids = position_ids.unsqueeze(0)
cos, sin = transformer_attn.rotary_emb(proj_v, 8)
cos, sin = transformer_attn.rotary_emb(proj_v, position_ids)
proj_q, proj_k = apply_rotary_pos_emb(proj_q, proj_k, cos, sin, position_ids)

pad_attn_output = attn.pad_context_forward(
Expand Down