3 changes: 3 additions & 0 deletions colossalai/shardformer/README.md
@@ -105,6 +105,8 @@ We will follow this roadmap to develop Shardformer:
  - [x] Whisper
- [ ] Multi-modal
  - [ ] To be added
  - [x] SAM
  - [x] BLIP-2
- [ ] Flash Attention Support
  - [ ] NLP
    - [x] BERT
@@ -119,6 +121,7 @@ We will follow this roadmap to develop Shardformer:
    - [ ] ERNIE
    - [ ] GPT Neo
    - [ ] GPT-J

## 💡 API Design

We will discuss the major components of `ShardFormer` below to help you better understand how things work.
60 changes: 60 additions & 0 deletions colossalai/shardformer/modeling/blip2.py
@@ -0,0 +1,60 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn


def forward_fn():
    # Replacement forward for BLIP-2's attention module. It matches the
    # original HuggingFace implementation except that the per-head dimension
    # in the qkv reshape is inferred with -1, so the code stays correct when
    # embed_dim is sharded across tensor-parallel ranks.

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)

        # modified from the original code, which was:
        # mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
        #     2, 0, 3, 1, 4
        # )
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))

        attention_scores = attention_scores * self.scale

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)

        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        output = self.projection(context_layer)

        outputs = (output, attention_probs) if output_attentions else (output, None)

        return outputs

    return forward
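As an aside: the companion policy in `colossalai/shardformer/policies/blip2.py` (whose diff is not rendered below) is what installs this closure. A minimal sketch of the usual wiring, assuming the `method_replacement` field of `ModulePolicyDescription` is used the way sibling policies such as SAM's use it; treat the exact placement as illustrative, not the verbatim file contents:

```python
# Hedged sketch, not the verbatim contents of policies/blip2.py:
# install the closure above as Blip2Attention.forward.
from transformers.models.blip_2.modeling_blip_2 import Blip2Attention

from colossalai.shardformer.modeling.blip2 import forward_fn
from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription

policy = {}
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
```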
4 changes: 1 addition & 3 deletions colossalai/shardformer/modeling/sam.py
@@ -2,10 +2,8 @@
from typing import Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed import ProcessGroup
from torch import Tensor


def forward_fn():
6 changes: 6 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
@@ -112,6 +112,12 @@ class PolicyLocation:
    # Sam
    "transformers.models.sam.modeling_sam.SamModel":
        PolicyLocation(file_name="sam", class_name="SamModelPolicy"),

    # Blip2
    "transformers.models.blip_2.modeling_blip_2.Blip2Model":
        PolicyLocation(file_name="blip2", class_name="Blip2ModelPolicy"),
    "transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGeneration":
        PolicyLocation(file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"),
}
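For orientation, here is a hedged sketch of how such a table can be consumed; `lookup_policy` and `policy_table` are illustrative stand-ins, not the actual helper names in `autopolicy.py`:

```python
# Illustrative sketch: resolve a model instance to its policy class via the
# fully-qualified class name used as the dictionary key above.
import importlib


def lookup_policy(model, policy_table):
    key = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = policy_table[key]    # a PolicyLocation(file_name=..., class_name=...)
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()
```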


304 changes: 304 additions & 0 deletions colossalai/shardformer/policies/blip2.py

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions colossalai/shardformer/policies/chatglm.py
@@ -90,7 +90,31 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                                                        policy=policy,
                                                        target_key=ChatGLMModel)

        else:
            self.append_or_create_submodule_replacement(description=[
                SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
                SubModuleReplacementDescription(suffix="post_attention_layernorm",
                                                target_module=col_nn.FusedRMSNorm)
            ],
                                                        policy=policy,
                                                        target_key=GLMBlock)

            if self.model.config.post_layer_norm:
                self.append_or_create_submodule_replacement(description=[
                    SubModuleReplacementDescription(suffix="encoder.final_layernorm",
                                                    target_module=col_nn.FusedRMSNorm)
                ],
                                                            policy=policy,
                                                            target_key=ChatGLMModel)

        return policy

    def postprocess(self):
        return self.model


class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):

    def module_policy(self):
        policy = super().module_policy()
        return policy
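The fused RMSNorm branch above only takes effect when the corresponding flag is set on the shard config. A minimal usage sketch, assuming the `ShardConfig`/`ShardFormer` entry points and an already-launched distributed context, with `model` standing in for a ChatGLM instance:

```python
# Minimal sketch: the flags that drive the branches in the policy above.
# Assumes a distributed environment is already initialized (e.g. via
# colossalai.launch) and `model` is a ChatGLM model instance.
from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=True)
shardformer = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shardformer.optimize(model)    # policy auto-resolved
```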
157 changes: 87 additions & 70 deletions colossalai/shardformer/policies/vit.py
@@ -2,7 +2,13 @@

import torch.nn as nn

from colossalai.shardformer.layer import DropoutForReplicatedInput, DropoutForParallelInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import (
DropoutForParallelInput,
DropoutForReplicatedInput,
FusedLayerNorm,
Linear1D_Col,
Linear1D_Row,
)

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -18,101 +24,112 @@ def preprocess(self):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            policy[ViTEmbeddings] = ModulePolicyDescription(
                attribute_replacement={},
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="dropout",
                        target_module=DropoutForReplicatedInput,
                    )
                ])

            policy[ViTLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "attention.attention.num_attention_heads":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "attention.attention.all_head_size":
                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                },
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attention.attention.query",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.key",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.value",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.attention.dropout",
                        target_module=DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dense",
                        target_module=Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dropout",
                        target_module=DropoutForReplicatedInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="intermediate.dense",
                        target_module=Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dense",
                        target_module=Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dropout",
                        target_module=DropoutForReplicatedInput,
                    ),
                ])

        if self.shard_config.enable_fused_normalization:
            policy[ViTModel] = ModulePolicyDescription(
                attribute_replacement={},
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="layernorm",
                        target_module=FusedLayerNorm,
                    )
                ])

            self.append_or_create_submodule_replacement(description=[
                SubModuleReplacementDescription(suffix="layernorm_before", target_module=FusedLayerNorm),
                SubModuleReplacementDescription(suffix="layernorm_after", target_module=FusedLayerNorm)
            ],
                                                        policy=policy,
                                                        target_key=ViTLayer)

        return policy

    def new_model_class(self):
        return None

    def postprocess(self):
        return self.model


class ViTForImageClassificationPolicy(ViTPolicy):

    def module_policy(self):
        from transformers.models.vit.modeling_vit import ViTForImageClassification

        policy = super().module_policy()
        if self.shard_config.enable_tensor_parallelism:
            new_item = {
                ViTForImageClassification:
                    ModulePolicyDescription(sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
                    ])
            }
            policy.update(new_item)
        return policy


class ViTForMaskedImageModelingPolicy(ViTPolicy):

    def module_policy(self):
        policy = super().module_policy()
        return policy
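A note on the `Linear1D_Col`/`Linear1D_Row` pairing used throughout this policy: query/key/value and `intermediate.dense` split their output features column-wise, while the following `attention.output.dense`/`output.dense` split input features row-wise, so each block needs only a single all-reduce at the end. A toy single-process sketch of why that pairing reproduces the unsharded result (plain tensors stand in for the real distributed layers):

```python
# Toy sketch of the column-parallel -> row-parallel pattern: each "rank" holds
# a chunk of the weights; summing the row-parallel partials (the all-reduce)
# recovers the full result.
import torch


def col_then_row(x, w_col, w_row, rank, world_size):
    out_slice = x @ w_col.chunk(world_size, dim=1)[rank]      # column-parallel slice
    return out_slice @ w_row.chunk(world_size, dim=0)[rank]   # row-parallel partial


x = torch.randn(2, 8)
w_col, w_row = torch.randn(8, 16), torch.randn(16, 8)
full = (x @ w_col) @ w_row
summed = sum(col_then_row(x, w_col, w_row, r, 2) for r in range(2))
assert torch.allclose(full, summed, atol=1e-4)
```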




1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/__init__.py
@@ -1,5 +1,6 @@
from .albert import *
from .bert import *
from .blip2 import *
from .bloom import *
from .chatglm import *
from .gpt import *