Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 30 additions & 44 deletions colossalai/inference/modeling/models/glide_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@

import torch
import torch.nn as nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
Expand Down Expand Up @@ -137,6 +133,7 @@ def glide_llama_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand All @@ -147,57 +144,43 @@ def glide_llama_model_forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
position_ids = position_ids.unsqueeze(0)

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
use_cache = False

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

# embed positions
hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None

for decoder_layer in self.layers:
if output_hidden_states:
Expand All @@ -212,6 +195,7 @@ def glide_llama_model_forward(
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

hidden_states = layer_outputs[0]
Expand All @@ -230,7 +214,9 @@ def glide_llama_model_forward(

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
Expand Down
19 changes: 13 additions & 6 deletions tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,22 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
torch.manual_seed(10)
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)

position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))

emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)

cos, sin = emb(x0, position_ids)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
cos = cos.reshape((TOTAL_TOKENS, -1))
sin = sin.reshape((TOTAL_TOKENS, -1))
cos_2 = cos[:, : D // 2]
sin_2 = sin[:, : D // 2]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
assert torch.allclose(embd_x0, embd_stimulated_x)

# create data
Expand Down
16 changes: 10 additions & 6 deletions tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,19 @@ def torch_rotary_emb(x, cos, sin):
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
cos, sin = emb(x0, position_ids)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
cos = cos.reshape((TOTAL_TOKENS, -1))
sin = sin.reshape((TOTAL_TOKENS, -1))
cos_2 = cos[:, :32]
sin_2 = sin[:, :32]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
assert torch.allclose(embd_x0, embd_stimulated_x)

# create data
Expand Down