144 changes: 92 additions & 52 deletions colossalai/shardformer/README.md
- [Policy](#policy)
- [Model Sharder](#model-sharder)
- [User-facing API](#user-facing-api)
- [⌨️ Development Notes](#️-development-notes)
  - [Add New Policy to Shardformer](#add-new-policy-to-shardformer)
  - [Write Your Unit Tests](#write-your-unit-tests)
- [📊 Benchmarking](#-benchmarking)
  - [System Performance](#system-performance)
  - [Convergence](#convergence)


## 🔗 Introduction
```python
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM

# create huggingface model as normal
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# use shardformer to optimize the model
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.optimize(model).to('cuda')

# do everything like normal
...
```

A customized policy can also be used:

```python
class MyPolicy(Policy):
    ...

# use customized policy to shard model
my_policy = MyPolicy()
shard_former.optimize(model, my_policy)


```
## 🗺 Roadmap

We will follow this roadmap to develop Shardformer:
Please refer to the code for more details.
<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/shardformer_flowchart.png" width="600" />
<br/>
</p>


```python
class ParallelModule(torch.nn.Module):
    ...
```
```python
@dataclass
class ShardConfig:
    tensor_parallel_process_group: ProcessGroup = None
    enable_fused_normalization: bool = False
    ...

    # Some possible future config fields
    pipeline_parallel_size: int                             # support pipeline parallelism
    tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d']  # support different tensor parallel modes
    inference_only: bool                                    # only inject inference-suitable sharding policy
    gather_output: bool                                     # gather the model output
    use_flash_attention: bool                               # whether to use flash attention to speed up attention
```

### Policy

A `Policy` is merely a description; the actual sharding will be performed by `ModelSharder`.
We abstract the policy into three stages:

1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding
2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a set of `ModulePolicyDescription` objects that tell `ModelSharder` how each submodule's attributes, child parameters, and deeper submodules will be substituted.
3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.

```python
@dataclass
class ModulePolicyDescription:
"""
Describe how the attributes and parameters will be transformed in a policy
r"""
Describe how the attributes and parameters will be transformed in a policy.

Args:
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive two arguments: module, process_group. One example is
def example_replace_weight(module: torch.nn.Module, process_group):
weight = module.weight
new_weight = shard_rowwise(weight, process_group)
module.weight = torch.nn.Parameter(new_weight)
sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies the module to be replaced and the target module used to replacement
param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive only one arguments: module.
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription
object which specifies the module to be replaced and the target module used to replacement.
method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
"""
attribute_replacement: Dict[str, Any]
param_replacement: List[Callable]
sub_module_replacement: List[SubModuleReplacementDescription]
attribute_replacement: Dict[str, Any] = None
param_replacement: List[Callable] = None
sub_module_replacement: List[SubModuleReplacementDescription] = None
method_replacement: Dict[str, Callable] = None

@dataclass
class SubModuleReplacementDescription:
"""
r"""
Describe how a submodule will be replaced

Args:
suffix (str): used to get the submodule object
target_module (ParallelModule): specifies the module class used to replace to submodule
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
target_module: ParallelModule
kwargs: Dict[str, Any] = None
ignore_if_not_exist: bool = False


class Policy(ABC):
"""
...


    @abstractmethod
    def postprocess(self) -> nn.Module:
        """
        perform some postprocessing work, for example, binding the embedding
        and classifier head weights of the BERT model
        """
        ...
```
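To make this concrete, below is a minimal sketch of what a `module_policy` implementation might return for a BERT-like model. The `BertLayer` key, the attribute path, the `attention.self.query` suffix, and the `Linear1D_Col` target module are illustrative assumptions, not guaranteed names in the library:

```python
# Sketch: module_policy() on a hypothetical BERT policy. BertLayer, the
# attribute path, the suffix, and Linear1D_Col are illustrative placeholders.
class MyBertPolicy(Policy):

    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertLayer
        return {
            BertLayer:
                ModulePolicyDescription(
                    # e.g. each of two tensor-parallel ranks keeps half of the heads
                    attribute_replacement={
                        "attention.self.num_attention_heads":
                            self.model.config.num_attention_heads // 2,
                    },
                    # swap the query projection for a column-parallel linear layer
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="attention.self.query",
                            target_module=Linear1D_Col,
                        ),
                    ],
                )
        }
```

`ModelSharder` then walks the model and applies this description to every matching `BertLayer` submodule.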
### Model Sharder

```python
class ModelSharder:

    def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None):
        # TODO: decide whether the input is a class or an object
        ...

    def shard(self) -> None:
        """
        Shard the model with parallelism with the help of pre-processing, replace_module, and post-processing.
        """
        ...

    def replace_module(self) -> None:
        """
        Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively.
        """
        ...
```

### User-facing API

```python
class ShardFormer:

    """
    Example:

        shard_former = ShardFormer(shard_config=shard_config)
        shard_former.init_distributed()
        model = shard_former.optimize(model, policy=policy)
        dataloader = shard_former.shard_dataset(dataset)
    """
    ...
```
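Putting it together, a minimal end-to-end run might look like the sketch below. The `colossalai.launch_from_torch` entry point and its `config` argument follow common ColossalAI usage and are assumptions here:

```python
import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer
from transformers import BertConfig, BertForMaskedLM

# initialize the distributed environment (assumed ColossalAI entry point)
colossalai.launch_from_torch(config={})

# build the model as usual
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

# shard it with shardformer
shard_config = ShardConfig(enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model).to('cuda')
```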

## ⌨️ Development Notes

### Add New Policy to Shardformer

This section serves as a guideline for writing new policies and registering them into `shardformer`.

- Step 1. Write your own model policy

You can create a new file in the `colossalai/shardformer/policies` folder and name it after the model. Implement your policy in this file. Do not import any model zoo library (such as `transformers`) at the top of the file, because we do not want to load the library when the policy is not used; such imports belong inside function bodies, only where needed. A skeleton of such a policy file is sketched after this list.

- Step 2. Register your policy to the autopolicy

Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file.

For example, to register the policy for the BERT model, we just add a key-value pair to the `_POLICY_LIST` dictionary. The key is the `qualname` of the model object (you can get it via `model.__class__.__qualname__`). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly, because the policy file may import libraries (such as `transformers`) that we do not want to load when the policy is not used.

```python
_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertModel":
        PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
}
```
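For reference, a new policy file might start from a skeleton like the one below (see Step 1). The `basepolicy` import path and the placeholder model names are assumptions for this sketch:

```python
# colossalai/shardformer/policies/my_model.py (illustrative skeleton)
from .basepolicy import ModulePolicyDescription, Policy  # assumed import path


class MyModelPolicy(Policy):

    def preprocess(self):
        # e.g. resize the embedding so the vocabulary size is divisible by the
        # tensor parallel size
        return self.model

    def module_policy(self):
        # import the model zoo library inside the function body, never at the
        # top of the file, e.g.:
        #   from transformers.models.bert import modeling_bert
        return {}

    def postprocess(self):
        # e.g. re-bind tied weights after sharding
        return self.model
```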

### Write Your Unit Tests

This section serves as a guideline for testing the `shardformer` module.

- Step 1. Add your model to the model zoo in the test kits.

Add your model to the model zoo under the `tests/kit/model_zoo` directory. This allows you to define test-related components for this model. You can take `tests/kit/model_zoo/transformers/llama.py` as a reference; a sketch of a registration entry is given after this list.

- Step 2. Write your unit testing for the model

Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency.


- Step 3. Execute your test

When running tests locally, you should run both your newly added test file and the whole `shardformer` test suite.

```bash
# test for your own test file
pytest tests/test_shardformer/test_model/<your-file>.py

# test for the whole shardformer module
pytest tests/test_shardformer
```
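For Step 1, a model zoo entry typically bundles a model constructor with a data generator. The sketch below mirrors the pattern in `tests/kit/model_zoo/transformers/llama.py`, but treat the exact `model_zoo.register` signature and the import path as assumptions and check the kit before copying:

```python
# Hypothetical model zoo entry; see tests/kit/model_zoo/transformers/llama.py
# for the authoritative pattern.
import torch
import transformers

from ..registry import ModelAttribute, model_zoo  # assumed import path


def data_gen():
    # a tiny deterministic batch is enough for shard-vs-unshard comparison
    input_ids = torch.tensor([[101, 7592, 2088, 102]], dtype=torch.int64)
    attention_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


model_zoo.register(name='transformers_my_model',  # placeholder name
                   model_fn=lambda: transformers.BertModel(transformers.BertConfig()),
                   data_gen_fn=data_gen,
                   output_transform_fn=lambda x: x,
                   model_attribute=ModelAttribute(has_control_flow=True))
```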

## 📊 Benchmarking

### System Performance

To be added.

### Convergence

To validate that training a model with Shardformer does not impact convergence, we [fine-tuned the BERT model](./examples/shardformer_benchmark.py) with and without Shardformer and compared the accuracy, F1 score, and loss of the training results.

| accuracy | f1 | loss | GPU number | model shard |
| :------: | :-----: | :-----: | :--------: | :---------: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |

Overall, the results demonstrate that using Shardformer during model training does not affect convergence.
2 changes: 1 addition & 1 deletion colossalai/shardformer/examples/shardformer_benchmark.py
def train(args):
    ...
    if dist.get_world_size() > 1:
        shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
        shard_former = ShardFormer(shard_config=shard_config)
        model = shard_former.optimize(model)

    optim = Adam(model.parameters(), lr=args.lr)
    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps