Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
ec217de
Feature/vit support (#4182)
klhhhhh Jul 7, 2023
ddecf73
[shardformer] support SAM (#4231)
FoolPlayer Jul 14, 2023
afcf4a0
[shardformer] support whisper (#4212)
FoolPlayer Jul 17, 2023
77cc087
Feature/chatglm (#4240)
klhhhhh Jul 20, 2023
6c2acf0
[shardformer] added tests
klhhhhh Jul 4, 2023
7668b24
[shardformer] vit test finish and support
klhhhhh Jul 6, 2023
b135b75
import chatglm
klhhhhh Jul 7, 2023
e3cd5cb
[shardformer] add test kit in model zoo for chatglm
klhhhhh Jul 7, 2023
30574a7
[shardformer] add first version of policy of chatglm
klhhhhh Jul 10, 2023
28677d4
[shardformer] polish chatglm code
klhhhhh Jul 12, 2023
28319c2
[shardformer] polish code
klhhhhh Jul 13, 2023
3f19de9
[shardformer] support chatglm without layernorm
klhhhhh Jul 14, 2023
2a4bbcf
[shardformer] delete some file
klhhhhh Jul 17, 2023
32448e3
[shardformer] ChatGLM support layernorm sharding
klhhhhh Jul 17, 2023
eb1c71a
[shardformer] register without auto policy
klhhhhh Jul 18, 2023
127e385
[shardformer] pre-commit check files
klhhhhh Jul 19, 2023
9d5b141
[shardformer] support ChatGLMForConditionalGeneration & add fusedlaye…
klhhhhh Jul 20, 2023
805f342
Merge pull request #4297 from klhhhhh/feature/support_ChatGLMForCondi…
klhhhhh Jul 21, 2023
f48a8bb
[shardformer] support Blip2 (#4243)
FoolPlayer Jul 25, 2023
bf2beb0
[shardformer] merge blip2 from shard-models branch
flybird11111 Jul 25, 2023
5f958bc
[shardformer] vit support flash attention and jit operator
flybird11111 Jul 26, 2023
b2b0c9c
[shardformer] vit support flash attention and jit operator
flybird11111 Jul 26, 2023
bb39b2e
Merge branch 'feature/flash-attention-shardformer' into update-vit
flybird11111 Jul 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions colossalai/shardformer/modeling/vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import math
from typing import Optional, Tuple, Union

import torch
from torch import nn


def get_vit_flash_self_attention_forward():
    """Build a drop-in replacement for ``ViTSelfAttention.forward`` that routes
    attention through ColossalAI's fused ``ColoAttention`` kernel.

    Returns:
        A function with the same signature as ``ViTSelfAttention.forward``,
        suitable for method replacement by a shardformer policy.
    """

    from transformers.models.vit.modeling_vit import ViTSelfAttention

    from colossalai.kernel.cuda_native.flash_attention import ColoAttention

    def split_heads(tensor: torch.Tensor, num_heads, head_size) -> torch.Tensor:
        # Reshape (..., hidden) -> (..., num_heads, head_size); no permute here,
        # ColoAttention consumes the head-split layout directly.
        target_shape = tensor.size()[:-1] + (num_heads, head_size)
        return tensor.view(target_shape)

    def forward(self: ViTSelfAttention,
                hidden_states: torch.Tensor,
                head_mask: Optional[torch.Tensor] = None,
                output_attentions: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # Project Q/K/V and split out the head dimension.
        query_layer = split_heads(self.query(hidden_states), self.num_attention_heads, self.attention_head_size)
        key_layer = split_heads(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
        value_layer = split_heads(self.value(hidden_states), self.num_attention_heads, self.attention_head_size)

        # NOTE(review): the flash path ignores `head_mask` and never materializes
        # attention probabilities, so `output_attentions=True` cannot be honored;
        # only the context tensor is returned.
        attention = ColoAttention(embed_dim=self.all_head_size,
                                  num_heads=self.num_attention_heads,
                                  dropout=self.dropout.p,
                                  scale=1.0 / math.sqrt(self.attention_head_size))
        context_layer = attention(query_layer, key_layer, value_layer)

        return (context_layer,)

    return forward


def get_jit_fused_vit_output_forward():
    """Build a replacement ``ViTOutput.forward`` that fuses dropout with the
    residual addition via the JIT-compiled ``dropout_add`` method.

    Returns:
        A function with the same signature as ``ViTOutput.forward``, suitable
        for method replacement by a shardformer policy.
    """

    from transformers.models.vit.modeling_vit import ViTOutput

    def forward(self: ViTOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        # `dropout_add` is patched onto the module by the same policy that
        # installs this forward; it applies dropout and the residual add in one op.
        return self.dropout_add(projected, input_tensor, self.dropout.p, self.dropout.training)

    return forward
17 changes: 16 additions & 1 deletion colossalai/shardformer/policies/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
Linear1D_Row,
)

from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.vit import get_jit_fused_vit_output_forward, get_vit_flash_self_attention_forward
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['ViTPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy']
Expand All @@ -24,7 +26,7 @@ def preprocess(self):
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel, ViTOutput, ViTSelfAttention

policy = {}

Expand Down Expand Up @@ -101,6 +103,19 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy=policy,
target_key=ViTLayer)

# use flash attention
if self.shard_config.enable_flash_attention:
policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={
'forward': get_vit_flash_self_attention_forward(),
})

# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[ViTOutput] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_vit_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})

return policy

def new_model_class(self):
Expand Down
9 changes: 6 additions & 3 deletions tests/test_shardformer/test_model/test_shard_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,17 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
all_shard_grad = shard_grad
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"


@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
enable_flash_attention, enable_jit_fused)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

Expand Down