Binary file added applications/Chat/coati/trainer/.sft.py.swp
Binary file not shown.
8 changes: 6 additions & 2 deletions colossalai/shardformer/layer/layers.py
@@ -770,6 +770,7 @@ def __init__(self,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
gather_output: bool = True,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
@@ -782,6 +783,7 @@ def __init__(self,
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.gather_output = gather_output

self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
@@ -832,8 +834,10 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars):
def forward(self, input_: Tensor) -> Tensor:

output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
if self.gather_output:
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
else:
output = output_parallel

return output

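For reference, a single-process sketch in plain PyTorch of what the new gather_output flag on Embedding1D controls (the world_size of 2 below is hypothetical; the real layer uses gather_forward_split_backward across the 1D parallel group): each rank embeds with its slice of the embedding dimension, and gather_output=True concatenates the partial activations along the last dimension, while gather_output=False leaves each rank with its own partition.

import torch
import torch.nn.functional as F

# Hypothetical 2-way split of an embedding weight along the embedding dim,
# matching how Embedding1D partitions its weight.
world_size, num_embeddings, embedding_dim = 2, 10, 8
full_weight = torch.randn(num_embeddings, embedding_dim)
shards = torch.chunk(full_weight, world_size, dim=1)      # one shard per "rank"

tokens = torch.tensor([[1, 4, 7]])
partial_outputs = [F.embedding(tokens, w) for w in shards]

# gather_output=True: gathering along dim=-1 yields the full activation.
gathered = torch.cat(partial_outputs, dim=-1)
print(gathered.shape)              # torch.Size([1, 3, 8])
# gather_output=False: each rank keeps only its partition.
print(partial_outputs[0].shape)    # torch.Size([1, 3, 4])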
9 changes: 9 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
@@ -43,6 +43,15 @@ def build_policies():

from .gpt2 import GPT2LMHeadModelPolicy
auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy

from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy
from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model
t5 = {
T5ForConditionalGeneration: T5ForConditionalGenerationPolicy,
T5EncoderModel: T5EncoderModelPolicy,
T5Model: T5ModelPolicy,
}
auto_policy_dict.update(t5)

return auto_policy_dict

12 changes: 12 additions & 0 deletions colossalai/shardformer/policies/basepolicy.py
@@ -80,6 +80,18 @@ class Dropout_Layer(Layer):
p: str = None


@dataclass
class Embedding_Layer(Layer):
r"""
Class for the embedding shard layer in tensor parallel

Args:
weight (str): The weight suffix of the layer
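gather_output (bool): Whether to gather the sharded embedding output along the embedding dimension in the forward pass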
"""
weight: str = None
gather_output: bool = True


class Policy():
r"""
The base class for all the policies
159 changes: 159 additions & 0 deletions colossalai/shardformer/policies/t5.py
@@ -0,0 +1,159 @@
from typing import Dict

import torch.nn as nn
from torch.nn import Embedding
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5Block,
T5DenseActDense,
T5DenseGatedActDense,
T5LayerCrossAttention,
T5LayerFF,
T5LayerSelfAttention,
T5Model,
T5Stack,
)

import colossalai.shardformer.layer.layers as col_nn

from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer


class T5ModelPolicy(Policy):

@staticmethod
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
print('config heads', config.num_heads)
return {
T5Stack:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]),
T5Block:
Argument(attr_dict={}, param_funcs=[]),
T5LayerSelfAttention:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
T5LayerCrossAttention:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
T5Attention:
Argument(attr_dict={
"d_model": config.d_model // world_size,
"n_heads": config.num_heads // world_size,
"inner_dim": config.num_heads * config.d_kv // world_size,
},
param_funcs=[T5ModelPolicy.attn_layer]),
T5LayerFF:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
T5DenseGatedActDense:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]),
T5DenseActDense:
Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]),
}

@staticmethod
def dense_gated_layer():
return [
Col_Layer(
suffix="wi_0",
weight="weight",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="wi_1",
weight="weight",
replace_layer=col_nn.Linear1D_Row,
),
Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)
]

@staticmethod
def dense_act_layer():
return [
Col_Layer(
suffix="wi",
weight="weight",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="wo",
weight="weight",
replace_layer=col_nn.Linear1D_Row,
)
]

@staticmethod
def attn_layer():
return [
Col_Layer(
suffix="q",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="k",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="v",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="o",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
),
]

@staticmethod
def dropout():
return [Dropout_Layer(
suffix="dropout",
p="p",
replace_layer=col_nn.Dropout1D,
)]

@staticmethod
def embedding():
return [
Embedding_Layer(
suffix="block[0].layer[0].SelfAttention.relative_attention_bias",
weight="weight",
replace_layer=col_nn.Embedding1D,
gather_output=False,
)
]


from transformers import T5ForConditionalGeneration


class T5ForConditionalGenerationPolicy(T5ModelPolicy):

@staticmethod
def argument_policy(config, world_size):
base_argument = T5ModelPolicy.argument_policy(config, world_size)
argument = {
T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head])
}
argument.update(base_argument)
return argument

@staticmethod
def lm_head():
return [Col_Layer(
suffix="lm_head",
weight="weight",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
)]


from transformers import T5EncoderModel


class T5EncoderModelPolicy(T5ModelPolicy):
pass
11 changes: 8 additions & 3 deletions colossalai/shardformer/shard/sharder.py
@@ -5,7 +5,7 @@
from transformers.pytorch_utils import Conv1D

from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer
from ..utils.utils import getattr_, hasattr_, setattr_
from .shard_config import ShardConfig
from .slicer import Slicer
@@ -155,11 +155,11 @@ def shard_one_layer(
assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
if suffix_layer is None and ignore:
continue
if isinstance(policy_layer, (Col_Layer, Row_Layer)):
if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)):
weight = None
bias = None
weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None
bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None

if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
@@ -189,6 +189,11 @@ def shard_one_layer(
weight.shape[1],
bias=False if bias is None else True,
gather_output=gather_output)
elif replace_layer_cls.__name__ == "Embedding1D":
gather_output = policy_layer.gather_output
replace_layer = replace_layer_cls(weight.shape[0],
weight.shape[1],
gather_output=gather_output)
elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))
6 changes: 4 additions & 2 deletions colossalai/shardformer/shard/slicer.py
@@ -1,9 +1,9 @@
import torch

from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer
from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer
from .shard_config import ShardConfig

dim_mapping = {Col_Layer: 0, Row_Layer: 1}
dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1}


class Slicer():
@@ -43,6 +43,8 @@ def slice_weight_bias(
bias = self.slice_tensor(bias, 0, True, n_cast)
elif policy_layer_cls == Row_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
elif policy_layer_cls == Embedding_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
else:
raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
if reversed:
25 changes: 22 additions & 3 deletions colossalai/shardformer/utils/utils.py
@@ -1,3 +1,22 @@
import re


def get_obj_list_element(obj, a):
re_pattern = r'\[\d+\]'
prog = re.compile(re_pattern)
result = prog.search(a)
if result:
matched_brackets = result.group()
matched_index = matched_brackets.replace('[', '')
matched_index = matched_index.replace(']', '')
a_ = a.replace(matched_brackets, '')
container_obj = getattr(obj, a_)
obj = container_obj[int(matched_index)]
else:
obj = getattr(obj, a)
return obj


def hasattr_(obj, attr: str):
r"""
Check whether the object has the multi sublevel attr
@@ -9,7 +28,7 @@ def hasattr_(obj, attr: str):
attrs = attr.split('.')
for a in attrs:
try:
obj = getattr(obj, a)
obj = get_obj_list_element(obj, a)
except AttributeError:
return False
return True
@@ -29,7 +48,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
attrs = attr.split('.')
for a in attrs[:-1]:
try:
obj = getattr(obj, a)
obj = get_obj_list_element(obj, a)
except AttributeError:
if ignore:
return
@@ -50,7 +69,7 @@ def getattr_(obj, attr: str, ignore: bool = False):
attrs = attr.split('.')
for a in attrs:
try:
obj = getattr(obj, a)
obj = get_obj_list_element(obj, a)
except AttributeError:
if ignore:
return None
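To illustrate the new helper above, here is a standalone sketch of how hasattr_/getattr_/setattr_ now walk attribute paths containing list indices, such as the T5 policy's suffix block[0].layer[0].SelfAttention.relative_attention_bias. The _Stack/_Block/_Layer/_SelfAttention classes are hypothetical stand-ins, not the transformers modules; the helper body mirrors get_obj_list_element as added in utils.py.

import re


def get_obj_list_element(obj, a):
    # "layer[0]" -> getattr(obj, "layer")[0]; plain names fall back to getattr.
    match = re.search(r'\[\d+\]', a)
    if match:
        index = int(match.group().strip('[]'))
        container = getattr(obj, a.replace(match.group(), ''))
        return container[index]
    return getattr(obj, a)


class _SelfAttention:
    relative_attention_bias = "relative-attention-bias-module"


class _Layer:
    SelfAttention = _SelfAttention()


class _Block:
    layer = [_Layer()]


class _Stack:
    block = [_Block()]


obj = _Stack()
for part in "block[0].layer[0].SelfAttention.relative_attention_bias".split('.'):
    obj = get_obj_list_element(obj, part)
print(obj)    # "relative-attention-bias-module"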
1 change: 1 addition & 0 deletions requirements/requirements-test.txt
@@ -15,3 +15,4 @@ einops
triton==2.0.0.dev20221202
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
SentencePiece