3 changes: 2 additions & 1 deletion colossalai/shardformer/modeling/bloom.py
@@ -6,6 +6,7 @@
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -21,7 +22,7 @@
BloomModel,
)
from transformers.utils import logging
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
16 changes: 9 additions & 7 deletions colossalai/shardformer/modeling/falcon.py
@@ -1,9 +1,12 @@
from typing import List, Optional, Tuple, Union
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
@@ -22,14 +25,13 @@
FalconForSequenceClassification,
FalconForTokenClassification,
FalconModel,
build_alibi_tensor,
apply_rotary_pos_emb,
build_alibi_tensor,
)
from transformers.utils import logging
import warnings

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from torch.nn import functional as F


def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
@@ -171,7 +173,7 @@ def forward(

def get_falcon_flash_attention_forward():
try:
from xformers.ops import memory_efficient_attention as me_attention
pass
except:
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
from transformers.models.falcon.modeling_falcon import FalconAttention
@@ -347,7 +349,7 @@ def falcon_model_forward(
past_key_values = None

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# case: First stage of training
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
@@ -449,7 +451,7 @@ def falcon_model_forward(
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
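Both the bloom and falcon hunks move `_prepare_4d_causal_attention_mask` into the regular third-party import block, and the falcon forward above calls it to expand a 2D padding mask into the 4D causal mask. A minimal sketch of that call, assuming transformers >= 4.36; the batch size, sequence length, and padding pattern below are illustrative and not taken from the PR:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Illustrative shapes (not from the PR).
batch_size, seq_length, hidden_size = 2, 8, 16
past_key_values_length = 0

# 2D padding mask: 1 = attend, 0 = padded (last two positions of sample 1 are padding).
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
attention_mask[1, -2:] = 0

# inputs_embeds supplies the dtype/device used when expanding the mask.
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
print(mask_4d.shape)  # (batch_size, 1, seq_length, past_key_values_length + seq_length)
```

The returned tensor holds 0.0 at positions a token may attend to and the dtype's minimum value at future or padded positions, which is the form the eager attention path expects.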
99 changes: 66 additions & 33 deletions colossalai/shardformer/modeling/llama.py
@@ -7,6 +7,7 @@
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -16,11 +17,12 @@
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaModel,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
from transformers.cache_utils import Cache

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import (
@@ -32,8 +34,6 @@

from ..layer import ColoAttention, cross_entropy_1d

from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa


class LlamaPipelineForwards:
"""
@@ -107,7 +107,10 @@ def llama_model_forward(

if position_ids is None:
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0)

@@ -117,26 +120,33 @@
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)

if self.gradient_checkpointing and self.training:
if use_cache:
@@ -159,7 +169,7 @@ def llama_model_forward(
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_layers=end_idx - start_idx,
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
)
assert num_ckpt_layers <= end_idx - start_idx

@@ -203,7 +213,16 @@ def llama_model_forward(
next_cache = next_decoder_cache if use_cache else None
if stage_manager.is_last_stage():
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@@ -307,7 +326,9 @@ def llama_for_causal_lm_forward(
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
@@ -446,12 +467,10 @@ def llama_for_sequence_classification_forward(
def get_llama_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

llama_version = 2
try:
from transformers.models.llama.modeling_llama import repeat_kv
except:
warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
llama_version = 1

def forward(
self: LlamaAttention,
@@ -494,8 +513,8 @@ def forward(
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
@@ -567,7 +586,10 @@ def forward(
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
@@ -581,7 +603,11 @@ def forward(
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)

if self.gradient_checkpointing and self.training:
@@ -736,7 +762,9 @@ def forward(
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
)

if not return_dict:
@@ -910,7 +938,10 @@ def forward(
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
@@ -926,7 +957,9 @@ def forward(

if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)

attention_mask = self._prepare_decoder_attention_mask(
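The main behavioral piece of this file is the mask selection in `llama_model_forward`: the raw 2D mask for flash-attention 2, `_prepare_4d_causal_attention_mask_for_sdpa` when SDPA is used and attentions are not requested, and the full 4D causal mask otherwise. Below is a standalone restatement of that branch, assuming transformers >= 4.36; the function name and the boolean parameters standing in for the model's `_use_flash_attention_2` / `_use_sdpa` attributes are illustrative, not part of the PR:

```python
from typing import Optional

import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)


def select_attention_mask(
    attention_mask: Optional[torch.Tensor],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    use_flash_attention_2: bool,
    use_sdpa: bool,
    output_attentions: bool,
) -> Optional[torch.Tensor]:
    batch_size, seq_length = inputs_embeds.shape[:2]
    if use_flash_attention_2:
        # Flash-attention kernels consume the raw 2D mask; drop it when nothing is padded.
        return attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    if use_sdpa and not output_attentions:
        # SDPA cannot return attention probabilities, so a 4D mask tailored to it is built.
        return _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
    # Eager attention: expand to the full 4D causal mask.
    return _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )
```

When the shard config enables flash attention instead, the forward above skips this branch entirely and builds its attention kwargs with `ColoAttention.prepare_attn_kwargs`.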
19 changes: 9 additions & 10 deletions colossalai/shardformer/modeling/mistral.py
@@ -1,21 +1,20 @@
from typing import Optional, Tuple
import warnings
from typing import List, Optional, Tuple, Union

import torch
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast
from typing import List, Optional, Tuple, Union
import warnings
from transformers.models.mistral.modeling_mistral import MistralModel
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.utils import logging
from transformers.cache_utils import Cache

logger = logging.get_logger(__name__)


class MistralForwards:

@staticmethod
def mistral_model_forward(
self:MistralModel,
self: MistralModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -94,7 +93,6 @@ def mistral_model_forward(
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for decoder_layer in self.layers:
if output_hidden_states:
@@ -123,7 +121,7 @@ def mistral_model_forward(
hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -145,6 +143,7 @@ def mistral_model_forward(
attentions=all_self_attns,
)


def get_mistral_flash_attention_forward():
from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv

@@ -218,7 +217,7 @@ def forward(

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, past_key_value

return forward
9 changes: 5 additions & 4 deletions colossalai/shardformer/modeling/opt.py
@@ -1,7 +1,9 @@
import random
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -15,7 +17,7 @@
OPTModel,
)
from transformers.utils import logging
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import ColoAttention
from colossalai.shardformer.shard import ShardConfig
@@ -55,7 +57,7 @@ class OPTPipelineForwards:
This class serves as a micro library for forward function substitution of OPT models
under pipeline setting.
"""

@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
@@ -70,7 +72,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]

return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


@staticmethod
def opt_model_forward(
self: OPTModel,
@@ -125,7 +126,7 @@ def opt_model_forward(
if decoder.project_in is not None:
inputs_embeds = decoder.project_in(inputs_embeds)
device = input_ids.device if input_ids is not None else inputs_embeds.device
_dtype = inputs_embeds.dtype
inputs_embeds.dtype
hidden_states = inputs_embeds
else:
if hidden_states is None: