Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions colossalai/shardformer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,30 +55,37 @@ If you wanna parallel the model in a custom way, just overwrite the policy class
You should do:

1. Inherit Policy class
2. Overwrite argument_policy method
- In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers.
3. Overwrite inject_policy method (Optional)
- If you need to modify the forward or backward progress.
4. Overwrite or add the param recording functions
2. Overwrite `argument_policy` method
- In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers. Shardformer will replace all the layer belonging to the class you specified.
- `attr_dict` is dict contains all the attributes need to be modified in this layer.
- `param_funcs` is a list contains some functions which will return the path of the weight and bias from the layer.
3. Overwrite `inject_policy` method (Optional)
- Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method.
4. Overwrite or add the param functions
- These functions use a suffix to record the path of weight or bias for the layer.
5. Overwrite binding
- The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively.
5. Overwrite `binding_policy` (Optional)
- Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers.
- This function will return a dict, the key and value are the suffix of weight need to be binded.

More details can be found in shardformer/policies/basepolicy.py
``` python
from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument

CustomPolicy(Policy):
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]:
"""
Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
r"""
Return the dict for the modify policy, the key is the original layer class and the value is the
argument for the modify layer

Args:
model_config: The config of transformer model
shard_setting: The config of distributed model
model_config (:class:`tansformer.Config`): The config of transformer model
shard_config (:class:`ShardConfig`): The config for sharding model

Return:
Dict for the modify policy,
::
{
origin layer class1 (nn.Module): Argument(
attr_dict = {
Expand Down Expand Up @@ -112,18 +119,29 @@ CustomPolicy(Policy):

@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
"""
r"""
Return the dict for the inject model

Return:
The injected model, key is the original model and value is the new shardmodel
::
(OrignModel, CustomModel)
in `CustomModel`, we can overwrite the forward and backward process
"""
return ()

@staticmethod
def binding_policy() -> Dict:
"""
r"""
Return the dict for the binding model

Return:
This method should return the binding relationship for some layers share the weight or bias,
the key and value is the suffix of the weight or bias of the model
::
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
"""
return NotImplementedError

Expand Down