Merged
50 commits
739af90  flash_attention forward upgrade (wangbluo, Mar 21, 2024)
976396c  llama_model_forward (wangbluo, Mar 25, 2024)
63ef374  remove useless comment (wangbluo, Mar 25, 2024)
b00f9ea  update the requirements.txt (wangbluo, Mar 25, 2024)
dc8b9d4  add the transformers version requirements (wangbluo, Mar 25, 2024)
9206dd1  remove the LATEST VERSION try (wangbluo, Mar 26, 2024)
cdb166c  Merge pull request #5499 from wangbluo/update_llama2 (wangbluo, Mar 26, 2024)
f1ebe54  [shardformer] update bloom model (#5518) (wangbluo, Apr 1, 2024)
2cdca4d  [shardformer] update_falcon (#5520) (wangbluo, Apr 3, 2024)
7686f4e  [shardformer] update mistral model (#5511) (wangbluo, Apr 3, 2024)
fd44440  [shardformer] update gpt2 (#5502) (wangbluo, Apr 3, 2024)
9a5edc3  [shardformer] update gptj model (#5503) (wangbluo, Apr 3, 2024)
cbff8c0  [shardformer] update opt (#5522) (wangbluo, Apr 3, 2024)
46479fb  [shardformer] update t5 model (#5524) (wangbluo, Apr 3, 2024)
d7af2d8  [shardformer] update whisper model (#5529) (wangbluo, Apr 3, 2024)
02d9b88  [shardformer] update vit model (#5530) (wangbluo, Apr 3, 2024)
2006339  Merge branch 'main' into feature/update-transformers (ver217, Apr 12, 2024)
c3e8215  [shardformer] fix llama modeling (ver217, Apr 12, 2024)
8b72eab  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 12, 2024)
c2fab31  Merge pull request #5592 from ver217/hotfix/shard-llama (wangbluo, Apr 12, 2024)
46b90f7  [zero] support multiple (partial) backward passes (#5596) (ver217, Apr 16, 2024)
b15b964  [zero] support multiple (partial) backward passes (#5596) (ver217, Apr 16, 2024)
4f5fee4  fix conflicts (wangbluo, Apr 18, 2024)
b323f0a  [doc] fix ColossalMoE readme (#5599) (Camille7777, Apr 15, 2024)
7cecde1  merge with main (wangbluo, Apr 18, 2024)
98eff6d  merge with main (wangbluo, Apr 18, 2024)
267efc8  llama_model_forward (wangbluo, Mar 25, 2024)
0bdcc84  remove useless comment (wangbluo, Mar 25, 2024)
e520e0b  remove the LATEST VERSION try (wangbluo, Mar 26, 2024)
2d9a21d  [shardformer] update bloom model (#5518) (wangbluo, Apr 1, 2024)
50b4c86  [shardformer] update mistral model (#5511) (wangbluo, Apr 3, 2024)
1233fc2  [shardformer] update opt (#5522) (wangbluo, Apr 3, 2024)
ab160a8  [shardformer] update whisper model (#5529) (wangbluo, Apr 3, 2024)
16a29ff  [shardformer] fix llama modeling (ver217, Apr 12, 2024)
06d7c30  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 12, 2024)
b427fee  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 18, 2024)
31b8ff4  Merge pull request #5607 from hpcaitech/merge-main (wangbluo, Apr 18, 2024)
0b2584d  [hotfix] Fix examples no pad token & auto parallel codegen bug; (#5606) (Edenzzzz, Apr 18, 2024)
cbea063  [shardformer] fix pipeline grad ckpt (#5620) (ver217, Apr 22, 2024)
46190f4  [shardformer] fix whisper (#5628) (ver217, Apr 23, 2024)
4a0b2de  [test] fix llama model test (ver217, Apr 23, 2024)
1556840  Merge pull request #5635 from ver217/hotfix/llama-upgrade (wangbluo, Apr 24, 2024)
2e2d1c1  fix the opt upgrade (#5634) (wangbluo, Apr 24, 2024)
e021cea  [shardformer] fix attn replacement (#5636) (ver217, Apr 24, 2024)
fa0d8ab  [shardformer] update flashattention replacement (#5637) (flybird11111, Apr 24, 2024)
d98ac05  Merge branch 'main' into feature/update-transformers (ver217, Apr 24, 2024)
52f4d3a  [test] fix llama test (#5638) (ver217, Apr 24, 2024)
fcceb78  [gemini] fix buffer cast (#5639) (ver217, Apr 24, 2024)
2ad14bd  Fix shardformer upgrade (#5640) (wangbluo, Apr 24, 2024)
c253a7e  [shardformer]support pipeline parallelism for mistral. (#5642) (flybird11111, Apr 24, 2024)
38 changes: 15 additions & 23 deletions colossalai/shardformer/modeling/bloom.py
@@ -6,6 +6,7 @@
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn import functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -205,12 +206,13 @@ def bloom_model_forward(
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
-
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config and shard_config.enable_sequence_parallelism:
@@ -227,21 +229,15 @@ def bloom_model_forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
@@ -1002,11 +998,13 @@ def forward(

         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1018,21 +1016,15 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
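For reference, a minimal sketch of the new mask-building call these hunks switch to, assuming transformers >= 4.36 (which exposes _prepare_4d_causal_attention_mask in transformers.modeling_attn_mask_utils). The shapes below are made-up toy values, not taken from the PR:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 8, 64
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)  # 2D padding mask
hidden_states = torch.randn(batch_size, seq_length, hidden_size)       # stand-in for inputs_embeds

# Builds a [batch_size, 1, seq_length, seq_length] additive causal mask;
# inputs_embeds is only consulted for its dtype and device.
causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=0,
)
# The updated forward then converts it back to a boolean mask: masked
# (nonzero) positions become True, which is what the Bloom blocks consume.
causal_mask = causal_mask.bool()
print(causal_mask.shape)  # torch.Size([2, 1, 8, 8])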
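Similarly, a self-contained sketch of the checkpointing pattern the diff adopts: newer transformers releases attach self._gradient_checkpointing_func to the model (roughly a partial over torch.utils.checkpoint.checkpoint, set up by gradient_checkpointing_enable()), so the per-model create_custom_forward closures can be dropped. TinyBlockStack below is a hypothetical stand-in, not the actual Bloom classes:

import functools

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlockStack(nn.Module):
    def __init__(self, hidden: int = 64, num_layers: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.gradient_checkpointing = False
        # In transformers this is wired up by gradient_checkpointing_enable();
        # here an equivalent partial is set directly for illustration.
        self._gradient_checkpointing_func = functools.partial(checkpoint, use_reentrant=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Same call shape as the updated Bloom forward: the bound
                # __call__ plus plain positional args, no wrapper closure.
                hidden_states = self._gradient_checkpointing_func(block.__call__, hidden_states)
            else:
                hidden_states = block(hidden_states)
        return hidden_states


model = TinyBlockStack()
model.gradient_checkpointing = True
model.train()
out = model(torch.randn(2, 8, 64, requires_grad=True))
out.sum().backward()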