Merged
Changes from all commits (26 commits)
ec217de
Feature/vit support (#4182)
klhhhhh Jul 7, 2023
ddecf73
[shardformer] support SAM (#4231)
FoolPlayer Jul 14, 2023
afcf4a0
[shardformer] support whisper (#4212)
FoolPlayer Jul 17, 2023
77cc087
Feature/chatglm (#4240)
klhhhhh Jul 20, 2023
6c2acf0
[shardformer] added tests
klhhhhh Jul 4, 2023
7668b24
[shardformer] finish vit test and add support
klhhhhh Jul 6, 2023
b135b75
import chatglm
klhhhhh Jul 7, 2023
e3cd5cb
[shardformer] add test kit in model zoo for chatglm
klhhhhh Jul 7, 2023
30574a7
[shardformer] add first version of chatglm policy
klhhhhh Jul 10, 2023
28677d4
[shardformer] polish chatglm code
klhhhhh Jul 12, 2023
28319c2
[shardformer] polish code
klhhhhh Jul 13, 2023
3f19de9
[shardformer] support chatglm without layernorm
klhhhhh Jul 14, 2023
2a4bbcf
[shardformer] delete some files
klhhhhh Jul 17, 2023
32448e3
[shardformer] ChatGLM support layernorm sharding
klhhhhh Jul 17, 2023
eb1c71a
[shardformer] register without auto policy
klhhhhh Jul 18, 2023
127e385
[shardformer] pre-commit check files
klhhhhh Jul 19, 2023
9d5b141
[shardformer] support ChatGLMForConditionalGeneration & add fusedlaye…
klhhhhh Jul 20, 2023
805f342
Merge pull request #4297 from klhhhhh/feature/support_ChatGLMForCondi…
klhhhhh Jul 21, 2023
f48a8bb
[shardformer] support Blip2 (#4243)
FoolPlayer Jul 25, 2023
bf2beb0
[shardformer] merge blip2 from shard-models branch
flybird11111 Jul 25, 2023
0d31eea
[shardformer] chatglm support flash attention and jit operator
flybird11111 Jul 26, 2023
4ae9d97
Merge branch 'feature/flash-attention-shardformer' into update-chatglm
flybird11111 Jul 26, 2023
f3f2ada
[shardformer] chatglm support flash attention and jit operator
flybird11111 Jul 26, 2023
c104392
Merge branch 'update-chatglm' of https://github.com/flybird1111/Colos…
flybird11111 Jul 26, 2023
4ee7a3c
[shardformer] chatglm support flash attention and jit operator
flybird11111 Jul 26, 2023
8e5fa05
[shardformer] chatglm support flash attention and jit operator
flybird11111 Jul 26, 2023
111 changes: 111 additions & 0 deletions colossalai/shardformer/modeling/chatglm.py
@@ -0,0 +1,111 @@
import torch
import torch.nn.functional as F


def get_flash_core_attention_forward():

    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention

    from .chatglm2_6b.modeling_chatglm import CoreAttention

    def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
        pytorch_major_version = int(torch.__version__.split(".")[0])
        if pytorch_major_version >= 2:
            # PyTorch >= 2.0: use the built-in scaled_dot_product_attention kernel.
            # Rearrange [s, b, np, hn] -> [b, np, s, hn] as that API expects.
            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                                 key_layer,
                                                                                 value_layer,
                                                                                 is_causal=True)
            else:
                if attention_mask is not None:
                    attention_mask = ~attention_mask
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 attention_mask)
            # Back to [s, b, np, hn], then merge the head dimensions into the hidden size.
            context_layer = context_layer.permute(2, 0, 1, 3)
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Raw attention scores
            query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
            key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
            value_layer = value_layer.permute(1, 0, 2, 3).contiguous()

            scale = 1.0 / self.norm_factor
            if self.coeff is not None:
                scale = scale * self.coeff

            flash_attention_mask = None
            attn_mask_type = None
            if attention_mask is None:
                attn_mask_type = AttnMaskType.causal
            else:
                flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
                attn_mask_type = AttnMaskType.paddedcausal

            attention = ColoAttention(embed_dim=self.hidden_size_per_partition,
                                      num_heads=self.num_attention_heads_per_partition,
                                      dropout=self.attention_dropout.p,
                                      scale=scale)
            context_layer = attention(query_layer,
                                      key_layer,
                                      value_layer,
                                      attn_mask=flash_attention_mask,
                                      attn_mask_type=attn_mask_type)

            context_layer = context_layer.permute(1, 0, -1).contiguous()

        return context_layer

    return forward


def get_jit_fused_glm_block_forward():

    from .chatglm2_6b.modeling_chatglm import GLMBlock

    def forward(
        self: GLMBlock,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
    ):
        # hidden_states: [s, b, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = self.dropout_add(attention_output, residual, self.hidden_dropout, self.training)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.dropout_add(mlp_output, residual, self.hidden_dropout, self.training)

        return output, kv_cache

    return forward
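
For reference, a minimal standalone sketch (not part of this diff) of the PyTorch >= 2.0 fast path used in get_flash_core_attention_forward above: tensors laid out as [seq, batch, heads, head_dim] are permuted into the [batch, heads, seq, head_dim] layout that torch.nn.functional.scaled_dot_product_attention expects, and the result is permuted back with the head dimensions merged into the hidden size. The tensor sizes below are illustrative assumptions, not values from the ChatGLM config.

import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions for this sketch only).
seq, batch, heads, head_dim = 8, 2, 4, 16

# CoreAttention receives activations laid out as [seq, batch, heads, head_dim].
q = torch.randn(seq, batch, heads, head_dim)
k = torch.randn(seq, batch, heads, head_dim)
v = torch.randn(seq, batch, heads, head_dim)

# Rearrange to [batch, heads, seq, head_dim] for scaled_dot_product_attention.
q, k, v = (t.permute(1, 2, 0, 3) for t in (q, k, v))

# Causal self-attention, matching the attention_mask is None branch above.
ctx = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Back to [seq, batch, heads, head_dim], then merge heads into the hidden size.
ctx = ctx.permute(2, 0, 1, 3)
ctx = ctx.reshape(*ctx.size()[:-2], heads * head_dim)
print(ctx.shape)  # torch.Size([8, 2, 64])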
17 changes: 16 additions & 1 deletion colossalai/shardformer/policies/chatglm.py
@@ -4,6 +4,8 @@

import colossalai.shardformer.layer as col_nn

from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
from ..modeling.jit import get_jit_fused_dropout_add_func
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
@@ -26,7 +28,7 @@ def preprocess(self):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock
        from ..modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock

        policy = {}

@@ -107,6 +109,19 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                                                     policy=policy,
                                                     target_key=ChatGLMModel)

        # use flash attention
        if self.shard_config.enable_flash_attention:
            policy[CoreAttention] = ModulePolicyDescription(method_replacement={
                'forward': get_flash_core_attention_forward(),
            })

        # use jit fused operator
        if self.shard_config.enable_jit_fused:
            policy[GLMBlock] = ModulePolicyDescription(method_replacement={
                'forward': get_jit_fused_glm_block_forward(),
                'dropout_add': get_jit_fused_dropout_add_func(),
            })

        return policy

    def postprocess(self):
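
The method_replacement entries above map a module class to new callables that should be bound onto matching instances at shard time. As a conceptual illustration only (this is not shardformer's actual sharder code, whose internals may differ), binding such replacements could look like:

from functools import partial

import torch.nn as nn


def apply_method_replacement(model: nn.Module, policy: dict) -> nn.Module:
    # Walk every submodule; when its class has a policy entry with method
    # replacements, bind each replacement function to that instance so that
    # calls like module.forward(...) dispatch to the new implementation.
    for module in model.modules():
        description = policy.get(type(module))
        if description is None:
            continue
        replacements = getattr(description, "method_replacement", None) or {}
        for name, func in replacements.items():
            # partial(func, module) fixes the `self` argument, mimicking a bound method.
            setattr(module, name, partial(func, module))
    return model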
5 changes: 3 additions & 2 deletions tests/kit/model_zoo/transformers/chatglm.py
@@ -1,9 +1,10 @@
import torch
import transformers

from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel

from ..registry import ModelAttribute, model_zoo
from .chatglm2_6b.configuration_chatglm import ChatGLMConfig
from .chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel

# ================================
# Register single-sentence ChatGLM
8 changes: 6 additions & 2 deletions tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -72,15 +72,19 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo

@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        # create new model
        org_model = model_fn().cuda()

        # shard model
        shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                   enable_tensor_parallelism=enable_tensor_parallelism)
                                   enable_tensor_parallelism=enable_tensor_parallelism,
                                   enable_flash_attention=enable_flash_attention,
                                   enable_jit_fused=enable_jit_fused)
        model_copy = copy.deepcopy(org_model)
        shard_former = ShardFormer(shard_config=shard_config)
        if name == "transformers_chatglm":
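
The test sweeps enable_flash_attention and enable_jit_fused on and off and relies on check_forward_backward (whose body is truncated above) to compare the original and sharded models. As a rough sketch of what such a parity check involves, assuming a transformers-style output with last_hidden_state and tensor parallelism disabled so that parameters line up one-to-one (the helper name and tolerances are assumptions, not the actual test code):

import torch


def assert_forward_backward_parity(org_model, sharded_model, data, atol=1e-5, rtol=1e-5):
    # Forward parity: both models should produce numerically matching outputs.
    org_out = org_model(**data).last_hidden_state
    shard_out = sharded_model(**data).last_hidden_state
    assert torch.allclose(org_out, shard_out, atol=atol, rtol=rtol)

    # Backward parity: gradients of the same scalar loss should also match.
    org_out.mean().backward()
    shard_out.mean().backward()
    for (name, p_org), (_, p_shard) in zip(org_model.named_parameters(), sharded_model.named_parameters()):
        if p_org.grad is not None and p_shard.grad is not None:
            assert torch.allclose(p_org.grad, p_shard.grad, atol=atol, rtol=rtol), name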