diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py
index 7b25f3e7489d..013b0f06185d 100644
--- a/colossalai/inference/modeling/models/glide_llama.py
+++ b/colossalai/inference/modeling/models/glide_llama.py
@@ -6,11 +6,7 @@
 
 import torch
 import torch.nn as nn
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
@@ -137,6 +133,7 @@ def glide_llama_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
@@ -147,57 +144,43 @@ def glide_llama_model_forward(
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
     # retrieve input_ids and inputs_embeds
-    if input_ids is not None and inputs_embeds is not None:
-        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-    elif input_ids is not None:
-        batch_size, seq_length = input_ids.shape[:2]
-    elif inputs_embeds is not None:
-        batch_size, seq_length = inputs_embeds.shape[:2]
-    else:
-        raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-    past_key_values_length = 0
-    if use_cache:
-        use_legacy_cache = not isinstance(past_key_values, Cache)
-        if use_legacy_cache:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-        past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError(
+            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
         )
-        position_ids = position_ids.unsqueeze(0)
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False
 
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
-    if self._use_flash_attention_2:
-        # 2d mask is passed through the layers
-        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-    elif self._use_sdpa and not output_attentions:
-        # output_attentions=True can not be supported when using SDPA, and we fall back on
-        # the manual implementation that requires a 4D causal mask in all cases.
-        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-        )
-    else:
-        # 4d mask is passed through the layers
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+    past_seen_tokens = 0
+    if use_cache:  # kept for BC (cache positions)
+        if not isinstance(past_key_values, StaticCache):
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_seen_tokens = past_key_values.get_seq_length()
+
+    if cache_position is None:
+        if isinstance(past_key_values, StaticCache):
+            raise ValueError("cache_position is a required argument when using StaticCache.")
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
         )
 
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
     # embed positions
     hidden_states = inputs_embeds
 
     # decoder layers
     all_hidden_states = () if output_hidden_states else None
     all_self_attns = () if output_attentions else None
-    next_decoder_cache = () if use_cache else None
+    next_decoder_cache = None
 
     for decoder_layer in self.layers:
         if output_hidden_states:
@@ -212,6 +195,7 @@ def glide_llama_model_forward(
             past_key_value=past_key_values,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
         )
 
         hidden_states = layer_outputs[0]
@@ -230,7 +214,9 @@ def glide_llama_model_forward(
 
     next_cache = None
     if use_cache:
-        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = (
+            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+        )
     if not return_dict:
         return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
     return BaseModelOutputWithPast(
diff --git a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
index 8237384c03fd..57a82647d49b 100644
--- a/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/cuda/test_rotary_embdding_unpad.py
@@ -28,15 +28,22 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     torch.manual_seed(10)
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op equals to Transformers
-    x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
-    x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
+    x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
+    x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
+
+    position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
+
     emb = LlamaRotaryEmbedding(D)
-    cos, sin = emb(x0, TOTAL_TOKENS)
+
+    cos, sin = emb(x0, position_ids)
+    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
+    cos = cos.reshape((TOTAL_TOKENS, -1))
+    sin = sin.reshape((TOTAL_TOKENS, -1))
     cos_2 = cos[:, : D // 2]
     sin_2 = sin[:, : D // 2]
-    position_ids = torch.arange(TOTAL_TOKENS)
-    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
-    embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
+    x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
+    embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
+    embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
     assert torch.allclose(embd_x0, embd_stimulated_x)
 
     # create data
diff --git a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
index 570093693447..78b7ba81c12b 100644
--- a/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_kernels/triton/test_rotary_embdding_unpad.py
@@ -43,15 +43,19 @@ def torch_rotary_emb(x, cos, sin):
 def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op equals to Transformers
-    x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
-    x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
+    x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
+    x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
     emb = LlamaRotaryEmbedding(D)
-    cos, sin = emb(x0, TOTAL_TOKENS)
+    position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
+    cos, sin = emb(x0, position_ids)
+    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
+    cos = cos.reshape((TOTAL_TOKENS, -1))
+    sin = sin.reshape((TOTAL_TOKENS, -1))
     cos_2 = cos[:, :32]
     sin_2 = sin[:, :32]
-    position_ids = torch.arange(TOTAL_TOKENS)
-    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
-    embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
+    x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
+    embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
+    embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
     assert torch.allclose(embd_x0, embd_stimulated_x)
 
     # create data
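Note on the rotary-embedding convention the two test diffs adapt to: in the newer transformers releases this patch targets, `LlamaRotaryEmbedding.forward` takes explicit `position_ids` instead of a sequence length and returns per-position `cos`/`sin` (one row per token), and `apply_rotary_pos_emb` no longer accepts a `position_ids` argument. The sketch below illustrates that calling convention only; the tensor shapes and version behavior are assumptions about the targeted transformers release, not part of this patch.

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

B, H, S, D = 2, 4, 8, 64  # batch, attention heads, sequence length, head dim (illustrative values)
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)

# One position id per token, shaped (batch, seq_len); the older API took a sequence length instead.
position_ids = torch.arange(S).unsqueeze(0).expand(B, S)

emb = LlamaRotaryEmbedding(D)
cos, sin = emb(q, position_ids)  # per-token cos/sin, roughly (B, S, D) in this transformers version
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # no position_ids argument anymore

assert q_rot.shape == (B, H, S, D) and k_rot.shape == (B, H, S, D)
```

This is why the updated tests build `x0`/`x1` as `(BATCH_SIZE, H, SEQ_LEN, D)` tensors and then flatten `cos`/`sin` to `(TOTAL_TOKENS, -1)` before feeding the kernel reference `torch_rotary_emb`, which still expects token-major inputs.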