2 changes: 1 addition & 1 deletion colossalai/shardformer/__init__.py
@@ -1 +1 @@
from .shard import ShardConfig, shard_model
from .shard import ShardConfig, ShardFormer
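This change replaces the functional `shard_model` entry point with a `ShardFormer` class. A minimal usage sketch under that reading; the `ShardConfig` construction and the sharding method name are assumptions, not confirmed by this diff:

```python
# Hypothetical usage of the new class-based entry point; ShardConfig's
# construction and ShardFormer's method name are assumptions, not
# confirmed by this diff.
import torch.nn as nn

from colossalai.shardformer import ShardConfig, ShardFormer

def shard(model: nn.Module) -> nn.Module:
    shard_config = ShardConfig()                    # assumed default construction
    shard_former = ShardFormer(shard_config=shard_config)
    return shard_former.shard_model(model)          # assumed successor to shard_model()
```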
42 changes: 18 additions & 24 deletions colossalai/shardformer/policies/autopolicy.py
@@ -1,5 +1,7 @@
import torch.nn as nn

from .basepolicy import Policy


def build_policies():
r"""
@@ -21,33 +23,25 @@ def build_policies():
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
from transformers.models.llama.modeling_llama import LlamaModel

from .llama import LlamaPolicy
auto_policy_dict[LlamaModel] = LlamaPolicy

from transformers import LlamaForSequenceClassification

from .llama import LlamaForSequenceClassificationPolicy
auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy

from transformers import LlamaForCausalLM

from .llama import LlamaForCausalLMPolicy
auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy

from transformers import GPT2Model

from .gpt2 import GPT2Policy
auto_policy_dict[GPT2Model] = GPT2Policy

from transformers import GPT2LMHeadModel

from .gpt2 import GPT2LMHeadModelPolicy
auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
# from .llama import LlamaPolicy
# auto_policy_dict[LlamaModel] = LlamaPolicy
# from transformers import LlamaForSequenceClassification
# from .llama import LlamaForSequenceClassificationPolicy
# auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
# from transformers import LlamaForCausalLM
# from .llama import LlamaForCausalLMPolicy
# auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
# from transformers import GPT2Model
# from .gpt2 import GPT2Policy
# auto_policy_dict[GPT2Model] = GPT2Policy
# from transformers import GPT2LMHeadModel
# from .gpt2 import GPT2LMHeadModelPolicy
# auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy

return auto_policy_dict


def get_autopolicy(model: nn.Module):
def get_autopolicy(model: nn.Module) -> Policy:
r"""
Return the auto policy for the model

@@ -63,7 +57,7 @@ def get_autopolicy(model: nn.Module):
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
)
return policy
return policy()


# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
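Beyond commenting out the HuggingFace policy registrations, the hunks above change `get_autopolicy` to return an instantiated `Policy` (note `return policy()` and the new return annotation). A hedged sketch of how a caller would consume that contract, using the `set_model` hook introduced in `basepolicy.py` below:

```python
# Sketch of the new get_autopolicy contract: callers now receive a Policy
# instance (note `return policy()`), so per-model state can live on it.
import torch.nn as nn

from colossalai.shardformer.policies.autopolicy import get_autopolicy

def prepare_policy(model: nn.Module):
    policy = get_autopolicy(model)  # raises NotImplementedError for unregistered models
    policy.set_model(model)         # instance-level state, enabled by this PR
    return policy
```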
245 changes: 79 additions & 166 deletions colossalai/shardformer/policies/basepolicy.py
@@ -1,90 +1,65 @@
# part of code modified from https://github.com/tunib-ai/parallelformers

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import torch.nn as nn

from ..shard.shard_config import ShardConfig

@dataclass
class Argument:
r"""
The argument class for the policy

Args:
attr_dict (Dict[str, Any]): The dict for the param setting
param_funcs (:class:`List[Callable]`): The list for the param functions
"""
attr_dict: Dict[str, Any]
param_funcs: List[Callable]

class ParallelModule():

@dataclass
class Layer:
r"""
The layer object for the policy

Args:
suffix: (str): the suffix of the layer.
replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
ignore (bool): Whether to ignore this layer if it is not in the model
reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in],
but in GPT2 `Conv1D` layer is [in, out] which is reversed.
n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices,
but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
each device should have a part of Q, K and V weight.
"""
suffix: str = None
replace_layer: Any = None
ignore: bool = False
reversed: bool = False
n_cast: int = None
def __init__(self):
pass


@dataclass
class Col_Layer(Layer):
class SubModuleReplacementDescription:
r"""
Class for col shard layer in tensor parrallel
Describe how a submodule will be replaced

Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
gather_output (bool): Whether to gather the output of the layer
suffix (str): used to get the submodule object
target_module (ParallelModule): specifies the module class used to replace the submodule
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
"""
weight: str = None
bias: str = None
gather_output: bool = False
suffix: str
target_module: ParallelModule
kwargs: Dict[str, Any] = None
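A hedged construction sketch for `SubModuleReplacementDescription` as defined above; the suffix path and the stand-in parallel linear class are illustrative names, not APIs from this diff:

```python
import torch.nn as nn

class MyParallelLinear(nn.Linear):
    """Stand-in target module; real parallel layers are not part of this diff."""

    @classmethod
    def from_native_module(cls, module: nn.Linear, **kwargs) -> "MyParallelLinear":
        # A real implementation would shard module.weight across the process group.
        return cls(module.in_features, module.out_features)

replacement = SubModuleReplacementDescription(
    suffix="attention.output.dense",  # dotted path to the submodule (illustrative)
    target_module=MyParallelLinear,   # plays the ParallelModule role here
    kwargs={"gather_output": True},   # forwarded to from_native_module, per the docstring
)
```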


@dataclass
class Row_Layer(Layer):
class ModulePolicyDescription:
r"""
Class for col shard layer in tensor parrallel
Describe how the attributes and parameters will be transformed in a policy

Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
"""
weight: str = None
bias: str = None
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
must receive two arguments: module, process_group. One example is

```python
def example_replace_weight(module: torch.nn.Module, process_group):
weight = module.weight
new_weight = shard_rowwise(weight, process_group)
module.weight = torch.nn.Parameter(new_weight)
```

@dataclass
class Dropout_Layer(Layer):
r"""
Class for dropout layer in tensor parrallel

Args:
p (str): The dropout rate suffix of the layer
sub_module_replacement: each element in the list is a SubModuleReplacementDescription object which specifies
the submodule to be replaced and the target module used for the replacement
"""
p: str = None
attribute_replacement: Dict[str, Any]
param_replacement: List[Callable]
sub_module_replacement: List[SubModuleReplacementDescription]
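Putting the three fields together, a hedged sketch of one policy entry; `shard_rowwise` is only a stub standing in for the helper named in the docstring example, and the attribute name is illustrative:

```python
import torch
import torch.nn as nn

def shard_rowwise(weight: torch.Tensor, process_group) -> torch.Tensor:
    # Stand-in for the helper referenced in the docstring example; a real
    # version would return only the rows owned by the current rank.
    return weight

def replace_weight(module: nn.Module, process_group) -> None:
    # In-place param replacement, matching the two-argument contract above.
    module.weight = nn.Parameter(shard_rowwise(module.weight, process_group))

entry = ModulePolicyDescription(
    attribute_replacement={"num_attention_heads": 4},  # illustrative: heads per rank
    param_replacement=[replace_weight],
    sub_module_replacement=[replacement],              # e.g. the sketch shown earlier
)
```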


class Policy():
class Policy(ABC):
r"""
The base class for all the policies

Each model should have its own policy class, e.g. BertPolicy for the BERT model
or OPTPolicy for the OPT model.

AutoPolicy:
Shardformer has already defined some policies for HuggingFace models; just set ``custom_policy`` = None
to use the auto policy. In shardformer's autopolicy, we define a base policy for each type of model,
@@ -99,137 +74,75 @@ class for the example.

"""

@staticmethod
def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]:
def __init__(self) -> None:
self.model = None

def set_model(self, model: nn.Module) -> None:
r"""
Return the dict for the modify policy, the key is the original layer class and the value is the
argument for the modify layer
Set model as an attribute of the Policy object so that we can access the model's attributes.

Args:
model_config (:class:`tansformer.Config`): The config of transformer model
world_size (int)): The world size of sharding model
model (:class:`nn.Module`): The model to be sharded
"""
self.model = model

@abstractmethod
def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
r"""
Perform some preprocessing of the model, like reshaping the embedding layer
"""

@abstractmethod
def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
r"""
Return the dict describing how to modify the model; the key is the original layer class and the value is
the corresponding ``ModulePolicyDescription``

Return:
Dict for the modify policy,
::
{
origin layer class1 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
origin layer class1 (nn.Module): ModulePolicyDescription(
attribute_replacement = {
"attribute1": value1,
"attribute2": value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
param_replacement = [
function1,
function2,
...
]
),
origin layer class2 (nn.Module): Argument(
attr_dict = {
argument1: value1,
argument2: value2,
...
},
param_funcs = [
staticmethod1,
staticmethod2,
],
sub_module_replacement = [
`SubModuleReplacementDescription` description1,
`SubModuleReplacementDescription` description2,
...
]
),
origin layer class2 (nn.Module): ModulePolicyDescription(
...
),
...
}

"""
raise NotImplementedError

@staticmethod
def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]:
@abstractmethod
def new_model_class(self) -> Union[Type[nn.Module], None]:
r"""
Return the dict for the inject model
Return the new model class used to replace the current model class; None means there is no need to modify the model class

Return:
The injected model, key is the original model and value is the new shardmodel
::
(OrignModel, CustomModel)
in `CustomModel`, we can overwrite the forward and backward process
"""
return None

@staticmethod
def binding_policy() -> Union[Dict[str, str], None]:
r"""
Return the dict for the binding model, None means no need to bind
New model class

Return:
This method should return the binding relationship for some layers share the weight or bias,
the key and value is the suffix of the weight or bias of the model
::
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
E.g.
```
return BertModel_
```
"""
return None

@staticmethod
def attn_in() -> Union[List, None]:
@abstractmethod
def postprocess(self) -> nn.Module:
r"""
Attention qkv layer
In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be
``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters
in ``Layer`` object can refer to the ``Layer`` class.

Returns:
List[Layer]: List of layer object, each layer is the new
"""
return None

@staticmethod
def attn_out() -> Union[List, None]:
r"""
Attention output projection layer

Returns:
List[Layer]: List of layer object
"""
return None

@staticmethod
def mlp_in() -> Union[List, None]:
r"""
h -> 4h mlp layer

Returns:
List[Layer]: List of layer object
"""
return None

@staticmethod
def mlp_out() -> Union[List, None]:
r"""
4h -> h mlp layer

Returns:
List[Layer]: List of layer object
"""
return None

@staticmethod
def embedding() -> Union[List, None]:
r"""
Partially slice the embedding layer

Return:
List[Layer]: List of layer object
"""
return None

@staticmethod
def unembedding() -> Union[List, None]:
r"""
Partially slice the embedding layer, None means there is no unembedding layer

Return:
List[Layer]: List of layer object
Perform some postprocessing of the model, like binding the weight of the embedding layer with
the classifier layer
"""
return None
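Taken together, the reworked `Policy` is now an abstract base class with four hooks. A minimal, hedged subclass sketch showing the expected shape of each override; the bodies are placeholders rather than behavior prescribed by this diff:

```python
import torch.nn as nn

class ExamplePolicy(Policy):
    # Illustrative subclass; each override mirrors one abstract hook above.

    def preprocess(self, shard_config: ShardConfig = None) -> nn.Module:
        # e.g. reshape the embedding layer before sharding (per the docstring)
        return self.model

    def module_policy(self, shard_config: ShardConfig = None):
        # map original layer classes to ModulePolicyDescription objects
        return {}

    def new_model_class(self):
        # None keeps the original model class
        return None

    def postprocess(self) -> nn.Module:
        # e.g. re-bind tied embedding and classifier weights (per the docstring)
        return self.model
```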