Merged
2 changes: 2 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
@@ -5,6 +5,8 @@

from .basepolicy import Policy

__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]


@dataclass
class PolicyLocation:
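For context, a minimal usage sketch of the names now exported from `autopolicy.py`. The call form of `get_autopolicy` is assumed from its name and is not shown in this diff:

```python
# Hedged sketch, not part of this PR: assumes get_autopolicy(model) resolves the
# Policy registered for the model's class via the PolicyLocation entries.
from transformers import BertConfig, BertForSequenceClassification

from colossalai.shardformer.policies.autopolicy import get_autopolicy

model = BertForSequenceClassification(BertConfig())
policy = get_autopolicy(model)    # assumed signature: model instance in, Policy instance out
print(type(policy).__name__)      # expected to be BertForSequenceClassificationPolicy
```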
2 changes: 2 additions & 0 deletions colossalai/shardformer/policies/basepolicy.py
@@ -8,6 +8,8 @@

from ..shard.shard_config import ShardConfig

__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]


class ParallelModule():

30 changes: 21 additions & 9 deletions colossalai/shardformer/policies/bert.py
@@ -1,18 +1,16 @@
import torch.nn as nn
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertForMultipleChoice,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMPredictionHead,
)

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
'BertForMultipleChoicePolicy'
]


class BertPolicy(Policy):

@@ -33,6 +31,8 @@ def preprocess(self):
return self.model

def module_policy(self):
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer

base_policy = {
BertLayer:
ModulePolicyDescription(
@@ -123,7 +123,7 @@ def module_policy(self):

def new_model_class(self):
# do nothing
return self.model
return None

def postprocess(self):
return self.model
@@ -143,6 +143,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead

module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
@@ -184,6 +186,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead

module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
@@ -221,6 +225,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead

module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
@@ -261,6 +267,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertForSequenceClassification

module_policy = super().module_policy()
addon_module = {
BertForSequenceClassification:
@@ -284,6 +292,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertForTokenClassification

module_policy = super().module_policy()
addon_module = {
BertForTokenClassification:
@@ -314,6 +324,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertForMultipleChoice

module_policy = super().module_policy()
addon_module = {
BertForMultipleChoice:
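The recurring change in this file (and in the policies below) is moving the `transformers` model imports from module level into `module_policy()`. A self-contained sketch of that deferred-import pattern, with purely illustrative names:

```python
# Illustrative only: LazyBertPolicy is a made-up class, not part of this PR.
class LazyBertPolicy:
    """Defers the `transformers` import until the policy dict is actually built."""

    def module_policy(self):
        # Importing inside the method keeps the policy module cheap to import and
        # avoids pulling in model classes that the caller may never use.
        from transformers.models.bert.modeling_bert import BertLayer
        return {BertLayer: "sub-module replacement description goes here"}
```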
14 changes: 12 additions & 2 deletions colossalai/shardformer/policies/gpt2.py
@@ -1,11 +1,15 @@
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy'
]


class GPT2Policy(Policy):

@@ -25,7 +29,9 @@ def preprocess(self):
return self.model

def module_policy(self):
base_policy = {
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model

return {
GPT2Model:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@@ -125,6 +131,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

module_policy = super().module_policy()
addon_module = {
GPT2LMHeadModel:
@@ -156,6 +164,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel

module_policy = super().module_policy()
addon_module = {
GPT2DoubleHeadsModel:
12 changes: 9 additions & 3 deletions colossalai/shardformer/policies/llama.py
@@ -1,13 +1,13 @@
from typing import Dict, Union

import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']


class LlamaPolicy(Policy):

@@ -26,7 +26,9 @@ def preprocess(self):
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
base_policy = {
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

return {
LlamaDecoderLayer:
ModulePolicyDescription(
attribute_replacement={
@@ -109,6 +111,8 @@ def postprocess(self):
class LlamaForCausalLMPolicy(LlamaPolicy):

def module_policy(self):
from transformers import LlamaForCausalLM

policy = super().module_policy()
# add a new item for causal lm
new_item = {
@@ -128,6 +132,8 @@ def module_policy(self):
class LlamaForSequenceClassificationPolicy(LlamaPolicy):

def module_policy(self):
from transformers import LlamaForSequenceClassification

policy = super().module_policy()

# add a new item for sequence classification
17 changes: 9 additions & 8 deletions colossalai/shardformer/policies/opt.py
@@ -1,15 +1,12 @@
from transformers.models.opt.modeling_opt import (
OPTAttention,
OPTDecoder,
OPTDecoderLayer,
OPTForCausalLM,
OPTForSequenceClassification,
)

from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
'OPTForQuestionAnsweringPolicy'
]


class OPTPolicy(Policy):

@@ -29,6 +26,8 @@ def preprocess(self):
return self.model

def module_policy(self):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer

base_policy = {
OPTDecoder:
ModulePolicyDescription(attribute_replacement={},
@@ -111,6 +110,8 @@ def __init__(self) -> None:
class OPTForCausalLMPolicy(OPTPolicy):

def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForCausalLM

policy = super().module_policy()
new_item = {
OPTForCausalLM:
27 changes: 14 additions & 13 deletions colossalai/shardformer/policies/t5.py
@@ -1,15 +1,4 @@
from transformers import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5DenseActDense,
T5DenseGatedActDense,
T5LayerCrossAttention,
T5LayerFF,
T5LayerSelfAttention,
T5Stack,
)

from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -34,7 +23,17 @@ def preprocess(self):
return self.model

def module_policy(self):
base_policy = {
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5DenseActDense,
T5DenseGatedActDense,
T5LayerCrossAttention,
T5LayerFF,
T5LayerSelfAttention,
T5Stack,
)

return {
T5Stack:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@@ -165,6 +164,8 @@ def postprocess(self):
class T5ForConditionalGenerationPolicy(T5ModelPolicy):

def module_policy(self):
from transformers import T5ForConditionalGeneration

policy = super().module_policy()

new_item = {
7 changes: 5 additions & 2 deletions colossalai/shardformer/policies/vit.py
@@ -1,12 +1,13 @@
from typing import Dict, Union

import torch.nn as nn
from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel

from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['ViTPolicy']


class ViTPolicy(Policy):

@@ -25,7 +26,9 @@ def preprocess(self):
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
base_policy = {
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer

return {
ViTEmbeddings:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
18 changes: 17 additions & 1 deletion colossalai/shardformer/shard/shard_config.py
@@ -19,6 +19,7 @@ class ShardConfig:
"""
tensor_parallel_process_group: int = None
enable_fused_normalization: bool = False
enable_all_optimization: bool = False

# TODO: add support for tensor parallel
# pipeline_parallel_size: int
Expand All @@ -27,6 +28,21 @@ class ShardConfig:
# inference_only: bool = True
# gather_output: bool = True

@property
def tensor_parallel_size(self):
return self._tensor_parallel_size

def __post_init__(self):
# get the parallel size
self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)

# turn on all optimizations if enable_all_optimization is set to True
if self.enable_all_optimization:
self._turn_on_all_optimization()

def _turn_on_all_optimization(self):
"""
Turn on all optimization.
"""
# you can add all the optimization flags here
self.enable_fused_normalization = True
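A hedged usage sketch of the updated `ShardConfig` (field and property names taken from this diff; assumes `torch.distributed` has already been initialized, since `__post_init__` calls `dist.get_world_size`):

```python
# Hedged sketch, not part of this PR.
from colossalai.shardformer.shard.shard_config import ShardConfig

# torch.distributed.init_process_group(...) must have been called before this point.
config = ShardConfig(enable_fused_normalization=False, enable_all_optimization=True)

print(config.tensor_parallel_size)        # read-only property backed by _tensor_parallel_size
print(config.enable_fused_normalization)  # expected True after _turn_on_all_optimization()
```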