28 changes: 21 additions & 7 deletions colossalai/shardformer/policies/bert.py
@@ -8,12 +8,6 @@
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class ParallelModule():

def __init__(self):
pass


class BertPolicy(Policy):

def preprocess(self, shard_config: ShardConfig = None):
@@ -49,7 +43,27 @@ def module_policy(self, shard_config: ShardConfig = None):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=ParallelModule,
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
),
])
}
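Reviewer note: the replacement pattern above follows the usual Megatron-style pairing, where `attention.self.{query,key,value}` and `intermediate.dense` become column-parallel layers and the projections that follow them (`attention.output.dense`, `output.dense`) become row-parallel layers. A minimal single-process sketch of why those pair up is below; `ToyColumnParallel`/`ToyRowParallel` are illustrative stand-ins of my own, not colossalai's `Linear1D_Col`/`Linear1D_Row`.

```python
# Illustrative single-process sketch (assumption: no real distributed backend):
# column-parallel layers shard the output features, row-parallel layers shard
# the input features, so a Col -> Row pair needs only one reduction at the end.
import torch
import torch.nn as nn


class ToyColumnParallel(nn.Module):
    """Each shard produces a slice of the output features."""

    def __init__(self, in_features: int, out_features: int, world_size: int):
        super().__init__()
        assert out_features % world_size == 0
        self.shards = nn.ModuleList(
            nn.Linear(in_features, out_features // world_size) for _ in range(world_size))

    def forward(self, x):
        # One output slice per "rank"; in real tensor parallelism each rank holds one shard.
        return [shard(x) for shard in self.shards]


class ToyRowParallel(nn.Module):
    """Each shard consumes a slice of the input features; partial sums are reduced."""

    def __init__(self, in_features: int, out_features: int, world_size: int):
        super().__init__()
        assert in_features % world_size == 0
        self.shards = nn.ModuleList(
            nn.Linear(in_features // world_size, out_features, bias=False)
            for _ in range(world_size))

    def forward(self, x_shards):
        # The sum plays the role of the all-reduce across tensor-parallel ranks.
        return sum(shard(x) for shard, x in zip(self.shards, x_shards))


if __name__ == '__main__':
    x = torch.randn(2, 768)
    col = ToyColumnParallel(768, 3072, world_size=2)   # analogous to intermediate.dense
    row = ToyRowParallel(3072, 768, world_size=2)      # analogous to output.dense
    print(row(col(x)).shape)                           # torch.Size([2, 768])
```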
15 changes: 7 additions & 8 deletions colossalai/shardformer/shard/sharder.py
@@ -7,8 +7,8 @@
from colossalai.cluster.process_group_manager import ProcessGroupManager

from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy
from ..utils.utils import setattr_
from ..policies.basepolicy import Policy, SubModuleReplacementDescription
from ..utils.utils import getattr_, setattr_
from .shard_config import ShardConfig

__all__ = ['ModelSharder', 'shard_model']
@@ -90,9 +90,7 @@ def replace_module(self,) -> None:
Args:
model (:class:`torch.nn.Module`): The model to shard
"""
print(self.policy)
module_descriptions = self.policy.module_policy(self.shard_config)
print(f"*******{module_descriptions}")
for module_description in module_descriptions.items():
origin_layer_cls = module_description[0]
attr_replacement = module_description[1].attribute_replacement
@@ -160,7 +158,7 @@ def _replace_param(
def _replace_sub_module(
self,
org_layer: nn.Module,
sub_module_replacement: List[Callable],
sub_module_replacement: List[SubModuleReplacementDescription],
) -> None:
r"""
Shard one layer according to the policy; the layer should be the same class as the key in the policy's argument_policy return dict
@@ -177,7 +175,8 @@ def _replace_sub_module(

assert target_module is not None, 'target_module should not be None'

# TODO: integrate with new layer
# replace_layer = target_module.from_native_layer(org_layer, self.pg_manager)
replace_layer = None
# TODO: support different parallel mode
native_sub_module = getattr_(org_layer, suffix)
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])

setattr_(org_layer, suffix, replace_layer)
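Reviewer note: the new call path is `getattr_(org_layer, suffix)` to fetch the native submodule, `target_module.from_native_module(...)` to build its parallel replacement, and `setattr_` to swap it back in. For reviewers who have not opened `..utils.utils`, dotted-path helpers of that shape typically look like the sketch below (my own illustration, not the library's exact implementation).

```python
# Illustrative sketch of dotted-path attribute helpers, assumed to behave like
# colossalai.shardformer.utils.getattr_/setattr_ (not copied from the library).
import functools

import torch.nn as nn


def getattr_(obj: nn.Module, path: str) -> nn.Module:
    """Resolve a dotted suffix such as 'attention.self.query'."""
    return functools.reduce(getattr, path.split('.'), obj)


def setattr_(obj: nn.Module, path: str, value: nn.Module) -> None:
    """Set the attribute named by the last path segment on its parent module."""
    parent_path, _, name = path.rpartition('.')
    parent = functools.reduce(getattr, parent_path.split('.'), obj) if parent_path else obj
    setattr(parent, name, value)
```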
28 changes: 17 additions & 11 deletions tests/test_shardformer/test_model/test_shard_bert.py
@@ -1,13 +1,13 @@
import copy
import os
import random

import pytest
import torch
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
Expand All @@ -20,15 +20,21 @@ def build_model(rank, world_size):
config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0

org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda')

shardconfig = ShardConfig(
rank=rank,
world_size=world_size,
gather_output=True,
)
sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
shardconfig).to('cuda')
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
org_model_forshard = copy.deepcopy(org_model)

org_model.to('cuda')
# TODO: no need to transfer to cuda
org_model_forshard.to('cuda')
shard_config = ShardConfig(tensor_parallel_size=2,
data_parallel_size=1,
pipeline_parallel_size=1,
tensor_parallel_mode='1d',
inference_only=True,
gather_output=True)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')

return org_model, sharded_model

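Reviewer note: with `inference_only=True` and `gather_output=True`, a natural follow-up to `build_model` is to run the same batch through both models and compare logits. The sketch below is hypothetical and not part of this diff; the prompt text and tolerance are assumptions.

```python
# Hypothetical sanity check (not in this PR): compare MaskedLM logits of the
# original and sharded models on one batch; atol is an assumed tolerance.
import torch
from transformers import AutoTokenizer


def check_outputs(org_model, sharded_model):
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    inputs = tokenizer('Hello, my dog is cute', return_tensors='pt').to('cuda')
    with torch.no_grad():
        org_logits = org_model(**inputs).logits
        shard_logits = sharded_model(**inputs).logits
    assert torch.allclose(org_logits, shard_logits, atol=1e-5), \
        'sharded forward output diverged from the original model'
```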