14 changes: 12 additions & 2 deletions colossalai/shardformer/policies/autopolicy.py
@@ -10,16 +10,26 @@ def build_policies():
"""
auto_policy_dict = {}

-from transformers.models.bert.modeling_bert import BertForMaskedLM
+from transformers import BertForMaskedLM

from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

-from transformers.models.bert.modeling_bert import BertForSequenceClassification
+from transformers import BertForSequenceClassification

from .bert import BertForSequenceClassificationPolicy
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

+from transformers import GPT2Model

+from .gpt2 import GPT2Policy
+auto_policy_dict[GPT2Model] = GPT2Policy

+from transformers import GPT2LMHeadModel

+from .gpt2 import GPT2LMHeadModelPolicy
+auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy

return auto_policy_dict


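For orientation, a minimal sketch of how this registry is meant to be consumed. The `get_policy_for` helper below is hypothetical (the repository's real entry point appears to be `get_autopolicy`, which the sharder imports); it only illustrates that policies are resolved by the model's class:

from transformers import GPT2Config, GPT2LMHeadModel

from colossalai.shardformer.policies.autopolicy import build_policies


def get_policy_for(model):
    # Hypothetical lookup: resolve a model instance to its registered policy class.
    auto_policy_dict = build_policies()
    policy_cls = auto_policy_dict.get(model.__class__)
    if policy_cls is None:
        raise NotImplementedError(f"No sharding policy for {model.__class__.__name__}")
    return policy_cls


model = GPT2LMHeadModel(GPT2Config())
print(get_policy_for(model))    # -> GPT2LMHeadModelPolicy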
17 changes: 11 additions & 6 deletions colossalai/shardformer/policies/basepolicy.py
@@ -1,11 +1,9 @@
# part of code modified from https://github.com/tunib-ai/parallelformers

-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type

-import torch
import torch.nn as nn
-from transformers import AutoConfig


@dataclass
@@ -31,11 +29,18 @@ class Layer:
bias (str): The bias suffix of the layer
replace_layer (:class:`colossalai.nn`): The layer to replace the original layer
ignore (bool): Whether to ignore this layer if it is not in the model
+        reversed (bool): Whether the layer's weight is stored transposed; `torch.nn.Linear` weights are
+            [out, in], but GPT2's `Conv1D` stores them as [in, out], i.e. reversed.
+        n_cast (int): The number of fused projections in the weight, e.g. 3 for the fused q, k, v of an
+            attention layer. Plain tensor parallelism chunks a weight across the devices; a fused attention
+            weight is first chunked into n_cast parts so every device receives its share of each of Q, K and V.
"""
weight: str = None
bias: str = None
replace_layer: Any = None
ignore: bool = False
+    reversed: bool = False
+    n_cast: int = None
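To make `reversed` and `n_cast` concrete, here is a self-contained sketch of a column-parallel slice under those two flags. `slice_col_weight` is a simplified stand-in written for illustration, assuming the real slicer behaves along these lines:

import torch


def slice_col_weight(weight, rank, world_size, n_cast=1, reversed=False):
    if reversed:                 # GPT2's Conv1D stores its weight as [in, out]
        weight = weight.t()      # normalize to Linear's [out, in] layout
    # split into n_cast * world_size chunks so each rank keeps a piece of every
    # cast, e.g. a piece of Q, of K and of V for a fused QKV weight
    chunks = weight.chunk(n_cast * world_size, dim=0)
    mine = torch.cat([chunks[i * world_size + rank] for i in range(n_cast)], dim=0)
    return mine.t() if reversed else mine


# fused QKV weight of GPT2's Conv1D: [hidden, 3 * hidden] with hidden=8, 2 ranks
w = torch.arange(8 * 24, dtype=torch.float32).reshape(8, 24)
print(slice_col_weight(w, rank=0, world_size=2, n_cast=3, reversed=True).shape)  # [8, 12]

With a naive n_cast=1 split, rank 0 would instead receive all of Q plus half of K, which is exactly what the cast-aware chunking prevents.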


@dataclass
@@ -131,7 +136,7 @@ def inject_policy() -> Tuple[nn.Module, nn.Module]:
(OriginModel, CustomModel)
in `CustomModel`, the forward and backward processes can be overridden
"""
-        return ()
+        return None

@staticmethod
def binding_policy() -> Dict:
@@ -146,7 +151,7 @@ def binding_policy() -> Dict:
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
"""
-        return NotImplementedError
+        return None

@staticmethod
def attn_in() -> List:
@@ -209,4 +214,4 @@ def unembedding() -> List:
Return:
List[Layer]: List of layer object
"""
-        return NotImplementedError
+        return None
1 change: 0 additions & 1 deletion colossalai/shardformer/policies/bert.py
@@ -1,4 +1,3 @@
-from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type

import torch.nn as nn
118 changes: 118 additions & 0 deletions colossalai/shardformer/policies/gpt2.py
@@ -0,0 +1,118 @@
from typing import Any, Callable, Dict, List, Tuple, Type

import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model

import colossalai.shardformer.layer.layers as col_nn

from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer


class GPT2Policy(Policy):

@staticmethod
def argument_policy(config, world_size):
return {
GPT2Model:
Argument(attr_dict={}, param_funcs=[
GPT2Policy.embedding,
]),
GPT2Block:
Argument(
attr_dict={
# 1. reduce hidden size
"attn.embed_dim": config.hidden_size // world_size,
"attn.split_size": config.hidden_size // world_size,
"crossattention.embed_dim": config.hidden_size // world_size,
"crossattention.split_size": config.hidden_size // world_size,
# 2. reduce number of heads
"attn.num_heads": config.num_attention_heads // world_size,
"crossattention.num_heads": config.num_attention_heads // world_size,
},
param_funcs=[
GPT2Policy.attn_in,
GPT2Policy.attn_out,
GPT2Policy.mlp_in,
GPT2Policy.mlp_out,
]),
}
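The attribute reductions above rescale the attention modules to their per-rank share; a quick sanity check with GPT2-small numbers (world_size chosen for illustration):

hidden_size, num_attention_heads, world_size = 768, 12, 4    # GPT2-small, 4-way TP
assert hidden_size % world_size == 0 and num_attention_heads % world_size == 0
print(hidden_size // world_size)            # 192: per-rank attn.embed_dim / attn.split_size
print(num_attention_heads // world_size)    # 3:   per-rank attn.num_heads
# head_dim is unchanged: 768 // 12 == 192 // 3 == 64, so the attention math still works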

@staticmethod
def attn_in() -> List:
return [
Col_Layer(weight="attn.c_attn.weight",
bias="attn.c_attn.bias",
n_cast=3,
reversed=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(weight="crossattention.c_attn.weight",
bias="crossattention.c_attn.bias",
n_cast=2,
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(weight="crossattention.q_attn.weight",
bias="crossattention.q_attn.bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col)
]

@staticmethod
def attn_out() -> List:
return [
Row_Layer(weight="attn.c_proj.weight",
bias="attn.c_proj.bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row),
Row_Layer(weight="crossattention.c_proj.weight",
bias="crossattention.c_proj.bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Row)
]

@staticmethod
def mlp_in() -> List:
return [
Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
]

@staticmethod
def mlp_out() -> List:
return [
Row_Layer(weight="mlp.c_proj.weight",
bias="mlp.c_proj.bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row)
]

@staticmethod
def embedding() -> List:
return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]


from transformers import GPT2LMHeadModel


class GPT2LMHeadModelPolicy(GPT2Policy):

@staticmethod
def argument_policy(config, world_size):
base_argument = GPT2Policy.argument_policy(config, world_size)
argument = {
GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[
GPT2LMHeadModelPolicy.unembedding,
]),
}
argument.update(base_argument)
return argument

@staticmethod
def unembedding() -> List:
return [
Col_Layer(weight="lm_head.weight",
bias="lm_head.bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True)
]
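`gather_output=True` on the LM head is the notable difference from the other column-parallel layers: each rank computes logits only for its slice of the (padded) vocabulary, and the partial outputs are gathered along the last dimension so the loss sees the full vocabulary. A shape sketch, assuming the vocabulary padding performed by `reshape_embedding` in the sharder below:

vocab_size, world_size = 50260, 4          # GPT2's 50257, padded up to a multiple of 4
local_vocab = vocab_size // world_size     # 12565 logits produced per rank
assert local_vocab * world_size == vocab_size    # gather restores full-vocab logits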
46 changes: 28 additions & 18 deletions colossalai/shardformer/shard/sharder.py
@@ -2,6 +2,7 @@

import torch
import torch.nn as nn
+from transformers.pytorch_utils import Conv1D

from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy
@@ -35,10 +36,22 @@ def __init__(
self.model_config = self.model.config

def shard(self) -> None:
+        self.reshape_embedding()
self.inject_model(self.model)
self.replace_layer(self.model)
self.bind_layer(self.model)

+    def reshape_embedding(self) -> None:
+        r"""
+        Pad the vocabulary size so that it is divisible by world_size, resizing the token embeddings accordingly
+        """
+        vocab_size = self.model_config.vocab_size
+        world_size = self.shard_config.world_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+            self.model_config = self.model.config
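A quick check of the padding arithmetic with GPT2's vocabulary:

vocab_size, world_size = 50257, 4
if vocab_size % world_size != 0:
    vocab_size = vocab_size + world_size - vocab_size % world_size
print(vocab_size)    # 50260, the smallest multiple of 4 that is >= 50257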

def inject_model(
self,
model: nn.Module,
@@ -53,6 +66,8 @@ def inject_model(
"""
inject_policy = self.policy.inject_policy()

+        if inject_policy is None:
+            return
org_model_cls = inject_policy[0]
shard_model_cls = inject_policy[1]

@@ -82,9 +97,9 @@ def replace_layer(
origin_layer_cls = argument_policy[0]
attr_dict = argument_policy[1].attr_dict
param_funcs = argument_policy[1].param_funcs
-            self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
+            self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)

-    def reverse_replace_layer(
+    def traverse_replace_layer(
self,
layer: nn.Module,
origin_cls: nn.Module,
@@ -100,17 +115,12 @@ def reverse_replace_layer(
attr_dict (Dict): The attribute dict to modify
policy_cls (:class:`Policy`): The policy class
"""
+        if layer.__class__ == origin_cls:
+            for k, v in attr_dict.items():
+                setattr_(layer, k, v, ignore=True)
+            self.shard_one_layer(layer, param_funcs)
        for name, child in layer.named_children():
-            if child.__class__ == origin_cls:
-                # replac_layer = child
-                for k, v in attr_dict.items():
-                    setattr_(child, k, v, ignore=True)
-                # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
-                # setattr_(layer, name, self.shard_one_layer(child, policy_cls))
-                self.shard_one_layer(child, param_funcs)
-                continue
-
-            self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
+            self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
return layer
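Besides the rename, `traverse_replace_layer` fixes a behavioral gap: the module itself is now checked before recursing into children, so a model whose root is the target class (e.g. a bare `GPT2Model`) gets sharded too. A minimal illustration of the difference:

import torch.nn as nn


def count_matches(module, origin_cls):
    hits = int(module.__class__ is origin_cls)    # new behavior: check the node itself
    for child in module.children():               # then recurse into every child
        hits += count_matches(child, origin_cls)
    return hits


print(count_matches(nn.Linear(2, 2), nn.Linear))    # 1; a child-only loop would report 0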

def shard_one_layer(
@@ -126,7 +136,6 @@ def shard_one_layer(
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class

"""
-        # print(org_layer)
for func in param_funcs:
policy_layers = func()
for policy_layer in policy_layers:
@@ -136,9 +145,10 @@
bias_attr = policy_layer.bias
replace_layer_cls = policy_layer.replace_layer
ignore = policy_layer.ignore
+                n_cast = policy_layer.n_cast
+                reversed = policy_layer.reversed
if policy_layer.__class__.__name__ == "Col_Layer":
gather_output = policy_layer.gather_output
-                    # print(gather_output)

if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
@@ -161,13 +171,11 @@
layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)

# slice weight and bias
-                weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
-                # print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
+                weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)

# create new object to replace the origin layer
if replace_layer_cls is not None:
# print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
if replace_layer_cls.__name__ == "Linear1D_Row":
replace_layer = replace_layer_cls(weight.shape[1],
weight.shape[0],
@@ -235,6 +243,8 @@ def bind_layer(self, model: nn.Module) -> None:
model (:class:`torch.nn.Module`): The shard model
"""
binding_map = self.policy.binding_policy()
+        if binding_map is None:
+            return
for k, v in binding_map.items():
param = getattr_(model, k)
param = nn.Parameter(param)
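Putting the pieces together, a hedged end-to-end sketch of how this sharder is presumably driven. The class name `ModelSharder`, the `ShardConfig` import path, and its `world_size` field are assumptions inferred from the attributes used above (`self.model`, `self.shard_config.world_size`, the auto policy), not a confirmed public API:

import torch.distributed as dist
from transformers import GPT2Config, GPT2LMHeadModel

from colossalai.shardformer.shard.sharder import ModelSharder       # assumed path
from colossalai.shardformer.shard.shardconfig import ShardConfig    # assumed path

# assumes launch via torchrun so that rank/world-size env vars are set
dist.init_process_group("nccl")

model = GPT2LMHeadModel(GPT2Config())
shard_config = ShardConfig(world_size=dist.get_world_size())        # assumed field
# policy=None presumably falls back to get_autopolicy for the model class
sharder = ModelSharder(model=model, policy=None, shard_config=shard_config)
sharder.shard()    # reshape embedding, inject model, replace layers, bind tied weights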