diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 877e28a2db0e..f5d8bb35d91d 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -15,7 +15,12 @@
- [Policy](#policy)
- [Model Sharder](#model-sharder)
- [User-facing API](#user-facing-api)
- - [Shardformer Convergence](#shardformer-convergence)
+ - [⌨️ Development Notes](#️-development-notes)
+ - [Add New Policy to Shardformer](#add-new-policy-to-shardformer)
+ - [Write Your Unit Testing](#write-your-unit-testing)
+ - [📊 Benchmarking](#-benchmarking)
+ - [System Performance](#system-performance)
+ - [Convergence](#convergence)
## 🔗 Introduction
@@ -40,12 +45,9 @@ config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
# create huggingface model as normal
-shard_config = ShardConfig(tensor_parallel_size=2,
- data_parallel_size=1,
- gather_output=True)
+shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
-shard_former.init_distributed()
-sharded_model = shard_former.shard_model(model).to('cuda')
+sharded_model = shard_former.optimize(model).to('cuda')
# do everything like normal
...
@@ -67,10 +69,11 @@ class MyPolicy(Policy):
# use customized policy to shard model
my_policy = MyPolicy()
-shard_former.shard_model(model, my_policy)
+shard_former.optimize(model, my_policy)
+
-```
+```
## 🗺 Roadmap
We will follow this roadmap to develop Shardformer:
@@ -112,7 +115,6 @@ Please refer to the code for more details.
- This diagram is deprecated, need to update it
@@ -147,15 +149,13 @@ class ParallelModule(torch.nn.Module):
```python
@dataclass
class ShardConfig:
- data_parallel_size: int
- tensor_parallel_size: int
+ tensor_parallel_process_group: ProcessGroup = None
+ enable_fused_normalization: bool = False
...
# Some possible future config fields
- pipeline_parallel_size: int # Support pipeline parallelism
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
inference_only: bool # only inject inference-suitable sharding policy
- gather_output: bool # gather the model output
use_flash_attention: bool # whether to use flash attention to speed up attention
```
@@ -166,42 +166,42 @@ It is merely a description, the actual sharding will be performed by `ModelShard
We abstract the policy into four stages:
1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding
-2. Providing a new class: call `Policy.new_model_class` to get a new class for the model, this class replaces attributes and the forward function
-3. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted.
-4. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
+2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a set of `ModulePolicyDescription` objects that tell `ModelSharder` how the submodules' attributes, child parameters, and deeper submodules will be substituted.
+3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model.
``` python
@dataclass
class ModulePolicyDescription:
- """
- Describe how the attributes and parameters will be transformed in a policy
+ r"""
+ Describe how the attributes and parameters will be transformed in a policy.
Args:
attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
- param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive two arguments: module, process_group. One example is
- def example_replace_weight(module: torch.nn.Module, process_group):
- weight = module.weight
- new_weight = shard_rowwise(weight, process_group)
- module.weight = torch.nn.Parameter(new_weight)
- sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies the module to be replaced and the target module used to replacement
+ param_replacement (List[Callable]): a list of functions to perform in-place param replacement. Each function must receive exactly one argument: module.
+ sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
+ object which specifies the module to be replaced and the target module used for replacement.
+ method_replacement (Dict[str, Callable]): key is the method name, value is the method for replacement
"""
- attribute_replacement: Dict[str, Any]
- param_replacement: List[Callable]
- sub_module_replacement: List[SubModuleReplacementDescription]
+ attribute_replacement: Dict[str, Any] = None
+ param_replacement: List[Callable] = None
+ sub_module_replacement: List[SubModuleReplacementDescription] = None
+ method_replacement: Dict[str, Callable] = None
@dataclass
class SubModuleReplacementDescription:
- """
+ r"""
Describe how a submodule will be replaced
Args:
suffix (str): used to get the submodule object
target_module (ParallelModule): specifies the module class used to replace to submodule
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
+ ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
target_module: ParallelModule
kwargs: Dict[str, Any] = None
+ ignore_if_not_exist: bool = False
class Policy(ABC):
@@ -230,13 +230,6 @@ class Policy(ABC):
"""
...
- @abstractmethod
- def new_model_class(self) -> Union[Type[nn.Module], None]:
- """
- replace the class of the model to substitute the forward and attributes
- """
- ...
-
@abstractmethods
def postprocess(self) -> nn.Module:
"""
@@ -253,8 +246,9 @@ class Policy(ABC):
```python
class ModelSharder:
- def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None)
+ def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None):
#TODO: input is a cls or a obj
+ ...
def shard(self) -> None:
"""
@@ -262,15 +256,6 @@ class ModelSharder:
"""
...
- def replace_model_class(self) -> None:
- """
- Replace the model's methods and attributes with our own defined class.
-
- E.g. we can replace the forward function of the original BertForMaskedLM object
- with the forward function we define in BertForMaskedLM_ class.
- """
- ...
-
def replace_module(self) -> None:
"""
Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively.
@@ -291,7 +276,7 @@ class ShardFormer:
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
- model = shard_former.shard_model(model, policy=policy)
+ model = shard_former.optimize(model, policy=policy)
dataloader = shard_former.shard_dataset(dataset)
"""
@@ -326,14 +311,69 @@ class ShardFormer:
...
```
-### Shardformer Convergence
+## ⌨️ Development Notes
+
+### Add New Policy to Shardformer
+
+This section serves as the guideline for writing new policies and registering them into `shardformer`.
+
+- Step 1. Write your own model policy
+
+You can create a new file in the `colossalai/shardformer/policies` folder and name the file with the model name. You can implement your policy in this file. You should not import any model zoo library at the header section of the file because we do not want to import the library when we do not use the policy. Libraries such as `transformers` should be imported only in the function body when needed.
+
+- Step 2. Register your policy to the autopolicy
+
+Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file.
+
+For example, if we register the policy for the BERT model, we just add a key-value pair in the `_POLICY_LIST` dictionary. The key is the `qualname` of the model object (you can get it by `model.__class__.__qualname__`). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy.
+
+```python
+_POLICY_LIST = {
+ # BERT
+ "transformers.models.bert.modeling_bert.BertModel":
+ PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
+}
+```
+
+### Write Your Unit Testing
+
+This section serves as the guideline for testing the `shardformer` module.
+
+- Step 1. Add your model to the model zoo in the test kits.
+
+Add your model to the `tests/kit/model_zoo` folder. This allows you to define test-related components for this model. You can take `tests/kit/model_zoo/transformers/llama.py` as an example for reference.
+
+- Step 2. Write your unit testing for the model
+
+Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency.
+
+
+- Step 3. Execute your test
+
+When you run tests locally, you should run tests for both your newly-added test file and the whole `shardformer` module tests.
+
+```bash
+# test for your own test file
+pytest tests/test_shardformer/test_model/YOUR_TEST_FILE.py
+
+# test for the whole shardformer module
+pytest tests/test_shardformer
+```
+
+## 📊 Benchmarking
+
+### System Performance
+
+To be added.
+
+### Convergence
To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results.
-| accuracy | f1 | loss | GPU number | model shard |
-| :-----: | :----: | :----: | :----: | :----: |
-| 0.82594 | 0.87441 | 0.09913 | 4 | True |
-| 0.81884 | 0.87299 | 0.10120 | 2 | True |
-| 0.81855 | 0.87124 | 0.10357 | 1 | False |
+| accuracy | f1 | loss | GPU number | model shard |
+| :------: | :-----: | :-----: | :--------: | :---------: |
+| 0.82594 | 0.87441 | 0.09913 | 4 | True |
+| 0.81884 | 0.87299 | 0.10120 | 2 | True |
+| 0.81855 | 0.87124 | 0.10357 | 1 | False |
Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
diff --git a/colossalai/shardformer/examples/shardformer_benchmark.py b/colossalai/shardformer/examples/shardformer_benchmark.py
index bb3560ee9e21..de82305b2547 100644
--- a/colossalai/shardformer/examples/shardformer_benchmark.py
+++ b/colossalai/shardformer/examples/shardformer_benchmark.py
@@ -51,7 +51,7 @@ def train(args):
if dist.get_world_size() > 1:
shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
shard_former = ShardFormer(shard_config=shard_config)
- model = shard_former.shard_model(model)
+ model = shard_former.optimize(model)
optim = Adam(model.parameters(), lr=args.lr)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py
index 2b972606948c..9ea3d95de5b2 100644
--- a/colossalai/shardformer/policies/basepolicy.py
+++ b/colossalai/shardformer/policies/basepolicy.py
@@ -22,9 +22,11 @@ class SubModuleReplacementDescription:
r"""
Describe how a submodule will be replaced
- suffix (str): used to get the submodule object
- target_module (ParallelModule): specifies the module class used to replace to submodule
- kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
+ Args:
+ suffix (str): used to get the submodule object
+ target_module (ParallelModule): specifies the module class used to replace to submodule
+ kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
+ ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
target_module: ParallelModule
@@ -35,47 +37,37 @@ class SubModuleReplacementDescription:
@dataclass
class ModulePolicyDescription:
r"""
- Describe how the attributes and parameters will be transformed in a policy
-
- attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
- param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
- must receive two arguments: module, process_group. One example is
-
- ```python
- def example_replace_weight(module: torch.nn.Module, process_group):
- weight = module.weight
- new_weight = shard_rowwise(weight, process_group)
- module.weight = torch.nn.Parameter(new_weight)
- ```
-
- sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies
- the module to be replaced and the target module used to replacement
+ Describe how the attributes and parameters will be transformed in a policy.
+
+ Args:
+ attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding
+ param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function
+ must receive exactly one argument: module. One example is
+
+ ```python
+ def example_replace_weight(module: torch.nn.Module):
+ weight = module.weight
+ new_weight = shard_rowwise(weight, process_group)
+ module.weight = torch.nn.Parameter(new_weight)
+ ```
+ sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
+ object which specifies the module to be replaced and the target module used for replacement.
+ method_replacement (Dict[str, Callable]): key is the method name, value is the method for replacement
"""
- attribute_replacement: Dict[str, Any]
- param_replacement: List[Callable]
- sub_module_replacement: List[SubModuleReplacementDescription]
- method_replacement: List[Callable] = None
+ attribute_replacement: Dict[str, Any] = None
+ param_replacement: List[Callable] = None
+ sub_module_replacement: List[SubModuleReplacementDescription] = None
+ method_replacement: Dict[str, Callable] = None
class Policy(ABC):
r"""
- The base class for all the policies
-
- For each different model, it should have a different policy class, like BertPolicy for Bert Model
- or OPTPolicy for OPT model.
-
- AutoPolicy:
- Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None
- to use the auto policy. In shardformer autopolicy, we define a base policy for one type model,
- like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM,
- BertForSequenceClassification, etc., for each different Bert model we difine different policy class
- and overwrite the method like ``inject_policy`` to modify the forward and backward process.
-
- CustomPolicy:
- If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
- all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy``
- class for the example.
+ The base class for all the policies. For each different model, it should have a different policy class,
+ like BertPolicy for Bert Model or OPTPolicy for OPT model.
+ Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
+ built-in policies by setting `policy = None`, which is already the default argument for `ShardFormer.optimize`.
+ If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
"""
def __init__(self) -> None:
@@ -106,63 +98,24 @@ def set_shard_config(self, shard_config: ShardConfig) -> None:
def config_sanity_check(self):
"""
Check if the shard config is valid for the model. Raise an exception if the config is invalid.
+ This method is made an abstractmethod with no default implementation because we want the policy writer
+ to take note of the features supported by their model and policy.
"""
pass
@abstractmethod
def preprocess(self) -> nn.Module:
r"""
- Perform some preprocessing of the model, like reshaping the embedding layer
+ Perform some preprocessing of the model, like reshaping the embedding layer.
"""
pass
@abstractmethod
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
r"""
- Return the dict for the modify policy, the key is the original layer class and the value is the
- argument for the modify layer
-
- Return:
- Dict for the modify policy,
- ::
- {
- origin layer class1 (nn.Module): ModulePolicyDescription(
- attribute_replacement = {
- "attribute1": value1,
- "attribute2": value2,
- ...
- },
- param_replacement = [
- function1,
- function2,
- ...
- ],
- sub_module_replacement = [
- `SubModuleReplacementDescription` description1,
- `SubModuleReplacementDescription` description2,
- ...
- ]
- ),
- origin layer class2 (nn.Module): ModulePolicyDescription(
- ...
- ),
- ...
- }
- """
- pass
-
- @abstractmethod
- def new_model_class(self) -> Union[Type[nn.Module], None]:
- r"""
- Return the new model class for the new model, None means no need to modify the model class
-
- Return:
- New model class
-
- E.g.
- ```
- return BertModel_
- ```
+ This method returns the module policy, which is a dictionary. The key is the module name or the module object,
+ and the value is the ModulePolicyDescription object. The ModulePolicyDescription object describes how the module
+ will be transformed.
"""
pass
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 7cf6caf7ca49..5ab8fb825244 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -48,7 +48,6 @@ def module_policy(self):
"crossattention.self.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
- param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
@@ -88,18 +87,16 @@ def module_policy(self):
)
]),
BertEmbeddings:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="word_embeddings",
- target_module=col_nn.VocabParallelEmbedding1D,
- ),
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=col_nn.DropoutForReplicatedInput,
- )
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="word_embeddings",
+ target_module=col_nn.VocabParallelEmbedding1D,
+ ),
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=col_nn.DropoutForReplicatedInput,
+ )
+ ])
}
# optimization configuration
@@ -121,10 +118,6 @@ def module_policy(self):
),)
return base_policy
- def new_model_class(self):
- # do nothing
- return None
-
def postprocess(self):
return self.model
@@ -148,13 +141,10 @@ def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="decoder",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": True}),
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
+ ])
}
# optimization configuration
@@ -191,13 +181,10 @@ def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="decoder",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": True}),
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
+ ])
}
if self.shard_config.enable_fused_normalization:
addon_module[BertLMPredictionHead].sub_module_replacement.append(
@@ -230,13 +217,10 @@ def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="decoder",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": True}),
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
+ ])
}
# optimization configuration
@@ -272,14 +256,12 @@ def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertForSequenceClassification:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=col_nn.DropoutForParallelInput,
- )
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ )
+ ])
}
module_policy.update(addon_module)
return module_policy
@@ -297,14 +279,12 @@ def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertForTokenClassification:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=col_nn.DropoutForParallelInput,
- )
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ )
+ ])
}
module_policy.update(addon_module)
return module_policy
@@ -329,14 +309,12 @@ def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertForMultipleChoice:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=col_nn.DropoutForParallelInput,
- )
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ )
+ ])
}
module_policy.update(addon_module)
return module_policy
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index c59cfbb405fc..00ab9159b0dc 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -98,7 +98,6 @@ def module_policy(self):
"self_attention.num_heads":
self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
- param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
@@ -125,7 +124,6 @@ def module_policy(self):
ModulePolicyDescription(attribute_replacement={
"num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
- param_replacement=[],
method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
sub_module_replacement=[
SubModuleReplacementDescription(
@@ -160,10 +158,6 @@ def module_policy(self):
return base_policy
- def new_model_class(self):
- # do nothing
- return None
-
def postprocess(self):
return self.model
@@ -180,13 +174,10 @@ def module_policy(self):
# add a new item for casual lm
new_item = {
BloomForCausalLM:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="lm_head",
- target_module=col_nn.Linear1D_Col,
- kwargs=dict(gather_output=True))
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
+ ])
}
policy.update(new_item)
return policy
@@ -213,13 +204,10 @@ def module_policy(self):
# add a new item for casual lm
new_item = {
BloomForSequenceClassification:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="score",
- target_module=col_nn.Linear1D_Col,
- kwargs=dict(gather_output=True))
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
+ ])
}
policy.update(new_item)
return policy
@@ -233,17 +221,14 @@ def module_policy(self):
# add a new item for casual lm
new_item = {
BloomForTokenClassification:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="classifier",
- target_module=col_nn.Linear1D_Col,
- kwargs=dict(gather_output=True)),
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=col_nn.DropoutForReplicatedInput,
- ),
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=col_nn.DropoutForReplicatedInput,
+ ),
+ ])
}
policy.update(new_item)
return policy
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index c6108f5c0e85..ad0b1144a8a5 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -31,23 +31,20 @@ def preprocess(self):
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
- return {
+ base_policy = {
GPT2Model:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="wte",
- target_module=col_nn.VocabParallelEmbedding1D,
- ),
- ]),
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="wte",
+ target_module=col_nn.VocabParallelEmbedding1D,
+ ),
+ ]),
GPT2Block:
ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
- param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
@@ -110,9 +107,6 @@ def module_policy(self):
return base_policy
- def new_model_class(self):
- return self.model
-
def postprocess(self):
return self.model
@@ -136,13 +130,10 @@ def module_policy(self):
module_policy = super().module_policy()
addon_module = {
GPT2LMHeadModel:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="lm_head",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": True})
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
+ ])
}
module_policy.update(addon_module)
return module_policy
@@ -169,13 +160,10 @@ def module_policy(self):
module_policy = super().module_policy()
addon_module = {
GPT2DoubleHeadsModel:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="lm_head",
- target_module=col_nn.Linear1D_Col,
- kwargs={"gather_output": True})
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
+ ])
}
module_policy.update(addon_module)
return module_policy
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 2fd2bc22303b..8f397693745c 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -28,7 +28,7 @@ def preprocess(self):
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
- return {
+ base_policy = {
LlamaDecoderLayer:
ModulePolicyDescription(
attribute_replacement={
@@ -37,7 +37,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
- param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
@@ -70,14 +69,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
],
),
LlamaModel:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="embed_tokens",
- target_module=VocabParallelEmbedding1D,
- )
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=VocabParallelEmbedding1D,
+ )
+ ])
}
# optimization configuration
@@ -101,9 +98,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
return base_policy
- def new_model_class(self):
- return None
-
def postprocess(self):
return self.model
@@ -117,13 +111,10 @@ def module_policy(self):
# add a new item for casual lm
new_item = {
LlamaForCausalLM:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="lm_head",
- target_module=Linear1D_Col,
- kwargs=dict(gather_output=True))
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+ ])
}
policy.update(new_item)
return policy
@@ -139,13 +130,10 @@ def module_policy(self):
# add a new item for sequence classification
new_item = {
LlamaForSequenceClassification:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="score",
- target_module=Linear1D_Col,
- kwargs=dict(gather_output=True))
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+ ])
}
policy.update(new_item)
return policy
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index dfbaaf5785ba..428ee2c9776c 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -31,33 +31,28 @@ def module_policy(self):
base_policy = {
OPTDecoder:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="embed_tokens",
- target_module=VocabParallelEmbedding1D,
- )
- ]),
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=VocabParallelEmbedding1D,
+ )
+ ]),
OPTDecoderLayer:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="fc1",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="fc2",
- target_module=Linear1D_Row,
- )
- ]),
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="fc1",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="fc2",
+ target_module=Linear1D_Row,
+ )
+ ]),
OPTAttention:
ModulePolicyDescription(attribute_replacement={
"embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
},
- param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q_proj",
@@ -95,9 +90,6 @@ def module_policy(self):
return base_policy
- def new_model_class(self):
- return None
-
def postprocess(self):
return self.model
@@ -116,13 +108,10 @@ def module_policy(self):
policy = super().module_policy()
new_item = {
OPTForCausalLM:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(suffix="lm_head",
- target_module=Linear1D_Col,
- kwargs=dict(gather_output=True))
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+ ])
}
policy.update(new_item)
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 8853687e7621..37fccaabc457 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -44,36 +44,30 @@ def module_policy(self):
base_policy = {
T5Stack:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=DropoutForParallelInput,
- ),
- SubModuleReplacementDescription(
- suffix="embed_tokens",
- target_module=Embedding1D,
- )
- ]),
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=Embedding1D,
+ )
+ ]),
T5LayerSelfAttention:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=DropoutForParallelInput,
- ),
- ]),
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=DropoutForParallelInput,
+ ),
+ ]),
T5LayerCrossAttention:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=DropoutForParallelInput,
- )
- ]),
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=DropoutForParallelInput,
+ )
+ ]),
T5Attention:
ModulePolicyDescription(attribute_replacement={
"d_model":
@@ -83,7 +77,6 @@ def module_policy(self):
"inner_dim":
self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
},
- param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q",
@@ -107,51 +100,44 @@ def module_policy(self):
ignore_if_not_exist=True)
]),
T5LayerFF:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=DropoutForParallelInput,
- ),
- ]),
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=DropoutForParallelInput,
+ ),
+ ]),
T5DenseGatedActDense:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="wi_0",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="wi_1",
- target_module=Linear1D_Row,
- ),
- SubModuleReplacementDescription(suffix="wo",
- target_module=Linear1D_Col,
- kwargs=dict(gather_output=True)),
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=DropoutForParallelInput,
- )
- ]),
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="wi_0",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="wi_1",
+ target_module=Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=DropoutForParallelInput,
+ )
+ ]),
T5DenseActDense:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="wi",
- target_module=Linear1D_Col,
- ),
- SubModuleReplacementDescription(
- suffix="wo",
- target_module=Linear1D_Row,
- ),
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=DropoutForParallelInput,
- )
- ])
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="wi",
+ target_module=Linear1D_Col,
+ ),
+ SubModuleReplacementDescription(
+ suffix="wo",
+ target_module=Linear1D_Row,
+ ),
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=DropoutForParallelInput,
+ )
+ ])
}
# optimization configuration
@@ -167,9 +153,6 @@ def module_policy(self):
return base_policy
- def new_model_class(self):
- return None
-
def postprocess(self):
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
@@ -185,14 +168,12 @@ def module_policy(self):
from transformers import T5Model
base_policy = super().module_policy()
- base_policy[T5Model] = ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="shared",
- target_module=VocabParallelEmbedding1D,
- )
- ])
+ base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="shared",
+ target_module=VocabParallelEmbedding1D,
+ )
+ ])
return base_policy
@@ -202,18 +183,14 @@ def module_policy(self):
from transformers import T5ForConditionalGeneration
policy = super().module_policy()
- policy[T5ForConditionalGeneration] = ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="shared",
- target_module=VocabParallelEmbedding1D,
- ),
- SubModuleReplacementDescription(
- suffix="lm_head",
- target_module=Linear1D_Col,
- kwargs=dict(gather_output=True))
- ])
+ policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="shared",
+ target_module=VocabParallelEmbedding1D,
+ ),
+ SubModuleReplacementDescription(
+ suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+ ])
return policy
def postprocess(self):
@@ -235,14 +212,12 @@ def module_policy(self):
from transformers import T5EncoderModel
base_policy = super().module_policy()
- base_policy[T5EncoderModel] = ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="shared",
- target_module=VocabParallelEmbedding1D,
- )
- ])
+ base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="shared",
+ target_module=VocabParallelEmbedding1D,
+ )
+ ])
return base_policy
def postprocess(self):
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 6a404c2faf0f..eaebe2eee0ba 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -28,16 +28,14 @@ def preprocess(self):
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
- return {
+ base_policy = {
ViTEmbeddings:
- ModulePolicyDescription(attribute_replacement={},
- param_replacement=[],
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="dropout",
- target_module=DropoutForReplicatedInput,
- )
- ]),
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="dropout",
+ target_module=DropoutForReplicatedInput,
+ )
+ ]),
ViTLayer:
ModulePolicyDescription(attribute_replacement={
"attention.attention.num_attention_heads":
@@ -45,7 +43,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
"attention.attention.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
},
- param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index c2573bc6d4dd..0a5aa4cc4bdc 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -3,8 +3,6 @@
import torch.distributed as dist
from torch.distributed import ProcessGroup
-from colossalai.cluster.dist_coordinator import DistCoordinator
-
__all__ = ['ShardConfig']
@@ -15,7 +13,9 @@ class ShardConfig:
Args:
tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
- enable_fused_normalization (bool): Whether to use fused layernorm, default is False
+ enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
+ enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True.
+ enable_all_optimization (bool): Whether to turn on all optimization, default is False.
"""
tensor_parallel_process_group: ProcessGroup = None
enable_fused_normalization: bool = False
@@ -45,4 +45,4 @@ def _turn_on_all_optimization(self):
Turn on all optimization.
"""
# you can add all the optimization flag here
- self.fused_layernorm = True
+ self.enable_fused_normalization = True
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 81c032b95f03..2867a0a4fd77 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -1,9 +1,7 @@
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Union
import torch.nn as nn
-from colossalai.cluster.process_group_manager import ProcessGroupManager
-
from .._utils import getattr_, setattr_
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy, SubModuleReplacementDescription
@@ -34,7 +32,6 @@ def shard(self) -> None:
self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self._preprocess()
- self._replace_model_class()
self._replace_module()
self._postprocess()
@@ -44,27 +41,6 @@ def _preprocess(self) -> None:
def _postprocess(self) -> None:
self.model = self.policy.postprocess()
- def _replace_model_class(self,) -> None:
- r"""
- Replace the model to policy defined model
- Mainly modify the forward and backward to fit distributed model
-
- e.g.
- ::
- BertForMaskedLM.forward -> BertForMaskedLM_.forward
- """
- new_model_class = self.policy.new_model_class()
- if new_model_class is None:
- return
-
- for key in new_model_class.__dict__.keys():
- if hasattr(self.model.__class__, key):
- setattr(
- self.model.__class__,
- key,
- getattr(new_model_class, key),
- )
-
def _replace_module(self,) -> None:
r"""
Replace the module according to the policy, and replace the module one by one
@@ -73,19 +49,18 @@ def _replace_module(self,) -> None:
model (:class:`torch.nn.Module`): The model to shard
"""
module_descriptions = self.policy.module_policy()
- for module_description in module_descriptions.items():
- origin_layer_cls = module_description[0]
- attr_replacement = module_description[1].attribute_replacement
- param_replacement = module_description[1].param_replacement
- sub_module_replacement = module_description[1].sub_module_replacement
- method_replacement = module_description[1].method_replacement
- self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement,
+ for layer_cls, module_description in module_descriptions.items():
+ attr_replacement = module_description.attribute_replacement
+ param_replacement = module_description.param_replacement
+ sub_module_replacement = module_description.sub_module_replacement
+ method_replacement = module_description.method_replacement
+ self._recursive_replace_layer(self.model, layer_cls, attr_replacement, param_replacement,
method_replacement, sub_module_replacement)
def _recursive_replace_layer(
self,
module: nn.Module,
- origin_cls: nn.Module,
+ origin_cls: Union[str, nn.Module],
attr_replacement: Dict[str, Any],
param_replacement: List[Callable],
method_replacement: Dict[str, Callable],
@@ -95,17 +70,25 @@ def _recursive_replace_layer(
Reverse the replace layer operation
Args:
- layer (:class:`torch.nn.Module`): The object of layer to shard
- origin_cls (:class:`transformers.model`): The origin layer class
+ module (torch.nn.Module): The module to shard
+ origin_cls (Union[str, torch.nn.Module]): The original layer class, or the name of that class as a string.
attr_replacement (Dict): The attribute dict to modify
param_replacement (List[Callable]): The function list to get parameter shard information in polic
sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
"""
- if module.__class__ == origin_cls:
- self._replace_attr(module, attr_replacement)
- self._replace_param(module, param_replacement)
- self._replace_method(module, method_replacement)
- self._replace_sub_module(module, sub_module_replacement)
+ if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
+ (module.__class__ == origin_cls):
+ if attr_replacement is not None:
+ self._replace_attr(module, attr_replacement)
+
+ if param_replacement is not None:
+ self._replace_param(module, param_replacement)
+
+ if method_replacement is not None:
+ self._replace_method(module, method_replacement)
+
+ if sub_module_replacement is not None:
+ self._replace_sub_module(module, sub_module_replacement)
for name, child in module.named_children():
self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement,
@@ -138,13 +121,10 @@ def _replace_param(
layer (:class:`torch.nn.Module`): The object of layer to shard
param_replacement (List[Callable]): The function list to get parameter shard information in policy
"""
- # TODO: support parameter shard
- pass
+ for param_func in param_replacement:
+ param_func(module)
def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
- if method_replacement is None:
- return
-
for method_name, new_method in method_replacement.items():
# bind the new method to the module
setattr(module, method_name, new_method.__get__(module, module.__class__))
@@ -158,8 +138,8 @@ def _replace_sub_module(
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
Args:
- org_layer (:class:`torch.nn.Module`): The origin layer object to shard
- param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
+ org_layer (torch.nn.Module): The origin layer object to shard
+ sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list
"""
for description in sub_module_replacement:
diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py
index 7c4220c3a9fb..3fce12463414 100644
--- a/colossalai/shardformer/shard/shardformer.py
+++ b/colossalai/shardformer/shard/shardformer.py
@@ -22,27 +22,19 @@ class ShardFormer:
colossalai.launch_from_torch(config={})
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
- shard_config = ShardConfig(
- tensor_parallel_size=2,
- tensor_parallel_mode='1d',
- )
+ shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
- model = shard_former.shard_model(org_model)
+ model = shard_former.optimize(org_model)
```
"""
def __init__(self, shard_config: ShardConfig):
- """
- Do two things:
- 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp
- 2. serve as a store for
- """
self.coordinator = DistCoordinator()
self.shard_config = shard_config
- def shard_model(self, model: nn.Module, policy: Policy = None):
+ def optimize(self, model: nn.Module, policy: Policy = None):
r"""
- The function is used to shard the PyTorch model.
+ This method will optimize the model based on the given policy.
Args:
model (`torch.nn.Model`): the origin huggingface model
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index a6355bf1c75e..aa1424af3289 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -11,7 +11,7 @@ def build_model(model_fn):
shard_config = ShardConfig(enable_fused_normalization=True)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
- sharded_model = shard_former.shard_model(model_copy).cuda()
+ sharded_model = shard_former.optimize(model_copy).cuda()
return org_model, sharded_model
diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py
index 61b672650965..9f8a5db6c94f 100644
--- a/tests/test_shardformer/test_with_torch_ddp.py
+++ b/tests/test_shardformer/test_with_torch_ddp.py
@@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port):
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create and shard model
model = model_fn().cuda()
- sharded_model = shardformer.shard_model(model)
+ sharded_model = shardformer.optimize(model)
# add ddp
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)