96 changes: 45 additions & 51 deletions colossalai/shardformer/modeling/llama.py
@@ -11,19 +11,15 @@
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
from transformers.utils import logging
from transformers.cache_utils import Cache

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d
from ..layer._operation import gather_forward_split_backward

try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

LATEST_VERSION = True
except ImportError:
LATEST_VERSION = False
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa


class LlamaPipelineForwards:
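
Note on the import changes above: both `transformers.cache_utils.Cache` and `_prepare_4d_causal_attention_mask_for_sdpa` first appear in transformers 4.36, which is why the old `LATEST_VERSION` try/except can be dropped once the version is pinned (see requirements.txt below). If a fork still needed to import cleanly on older releases, a guard along these lines would be required (a hedged sketch; `HAS_SDPA_MASK_HELPER` is an illustrative name, not part of this PR):

```python
try:
    from transformers.models.llama.modeling_llama import (
        _prepare_4d_causal_attention_mask,
        _prepare_4d_causal_attention_mask_for_sdpa,
    )
    HAS_SDPA_MASK_HELPER = True  # transformers >= 4.36
except ImportError:
    # Older transformers: the SDPA helper (and the Cache API) do not exist yet,
    # so callers would have to fall back to the pre-4.36 mask/cache code paths.
    _prepare_4d_causal_attention_mask_for_sdpa = None
    HAS_SDPA_MASK_HELPER = False
```
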
@@ -62,13 +58,13 @@ def llama_model_forward(
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -100,22 +96,23 @@ def llama_model_forward(
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()

# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
if LATEST_VERSION:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
position_ids = position_ids.unsqueeze(0)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
attention_mask = self._prepare_decoder_attention_mask(
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
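
The rewritten mask handling above mirrors `LlamaModel.forward` in transformers 4.36: flash-attention-2 keeps the raw 2D mask (or drops it when nothing is padded), SDPA without `output_attentions` builds its mask with `_prepare_4d_causal_attention_mask_for_sdpa`, and the eager path always expands to a 4D causal mask. A standalone sketch of that dispatch (the helper signatures follow 4.36; `build_decoder_mask` and `attn_impl` are illustrative names, not code from this PR):

```python
import torch
from transformers.models.llama.modeling_llama import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

def build_decoder_mask(attention_mask, hidden_states, past_key_values_length,
                       attn_impl="eager", output_attentions=False):
    batch_size, seq_length = hidden_states.shape[:2]
    if attn_impl == "flash_attention_2":
        # Flash attention consumes the raw 2D mask; pass None when no position is masked out.
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    if attn_impl == "sdpa" and not output_attentions:
        # SDPA cannot return attention weights, so this path is only taken without them.
        return _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )
    # Eager attention always needs the fully expanded 4D causal mask.
    return _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
    )
```

For a fully unpadded batch the flash-attention branch returns `None`, while the other two branches return a `(batch, 1, seq, kv_seq)` mask as expected by the decoder layers.
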

@@ -129,45 +126,38 @@
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
next_decoder_cache = None

start_idx, end_idx = stage_index[0], stage_index[1]
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)

past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:

def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)

return custom_forward

layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
None,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
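
Two changes in the layer loop above come straight from transformers 4.36: the hand-rolled `create_custom_forward` wrapper is replaced by `self._gradient_checkpointing_func`, which `gradient_checkpointing_enable()` now installs on the model, and `next_decoder_cache` holds a single `Cache` object instead of accumulating per-layer tuples. A hedged sketch of how that checkpointing function gets set (the `functools.partial` equivalence is an assumption about 4.36 internals, and the tiny config is only for illustration):

```python
import functools
import torch
from transformers import LlamaConfig, LlamaModel

# A deliberately tiny model just to show the call; any PreTrainedModel behaves the same way.
model = LlamaModel(LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                               num_attention_heads=4, num_key_value_heads=4, vocab_size=128))
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# Roughly what the call above stores on the model (assumed, based on 4.36 behaviour):
checkpoint_func = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

# The forward pass then checkpoints the bound __call__ so module hooks still fire:
#     layer_outputs = self._gradient_checkpointing_func(decoder_layer.__call__, *layer_args)
```
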

@@ -438,11 +428,15 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

@@ -452,23 +446,23 @@

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# repeat k/v heads if n_kv_heads < n_heads
if llama_version == 2:
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
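
The attention changes above replace the legacy `(key, value)` tuple cache with the transformers 4.36 `Cache` API: the past length comes from `get_usable_length(...)` (which needs the `layer_idx` the new error message asks for), and `Cache.update(...)` both appends the new states and returns the concatenated tensors, replacing the old `torch.cat` calls. A minimal, hedged sketch of that API with illustrative shapes:

```python
import torch
from transformers.cache_utils import DynamicCache

batch_size, num_heads, seq_len, head_dim = 1, 4, 8, 16
layer_idx = 0
# LlamaModel builds this from legacy tuples via DynamicCache.from_legacy_cache(...).
past_key_value = DynamicCache()

key_states = torch.randn(batch_size, num_heads, seq_len, head_dim)
value_states = torch.randn(batch_size, num_heads, seq_len, head_dim)

# How many cached positions precede the new tokens for this layer (0 on the first step).
kv_seq_len = key_states.shape[-2] + past_key_value.get_usable_length(key_states.shape[-2], layer_idx)

# Append the new keys/values and get back the full concatenated tensors in one call.
cache_kwargs = {"sin": None, "cos": None}  # ignored by DynamicCache, consumed by RoPE-aware caches
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)

assert kv_seq_len == seq_len and key_states.shape[-2] == seq_len
```
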
1 change: 0 additions & 1 deletion requirements/requirements-test.txt
@@ -3,7 +3,6 @@ pytest
coverage==7.2.3
git+https://github.com/hpcaitech/pytest-testmon
torchvision
transformers==4.33.0
timm
titans
torchaudio
1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -16,3 +16,4 @@ ray
sentencepiece
google
protobuf
transformers==4.36.0
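
With the old `transformers==4.33.0` test pin removed and the runtime requirement pinned to 4.36.0, only the new cache and mask code paths above are exercised. A small sanity check one might run in an existing environment (an illustrative snippet, not part of this PR):

```python
import transformers
from packaging import version

assert version.parse(transformers.__version__) >= version.parse("4.36.0"), (
    f"shardformer's llama forward now relies on transformers>=4.36 "
    f"(Cache API and SDPA mask helpers); found {transformers.__version__}"
)
```
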