16 changes: 7 additions & 9 deletions colossalai/shardformer/shard/shard_config.py
@@ -1,5 +1,8 @@
from dataclasses import dataclass

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.cluster.dist_coordinator import DistCoordinator

__all__ = ['ShardConfig']
@@ -11,10 +14,10 @@ class ShardConfig:
The config for sharding the huggingface model

Args:
tensor_parallel_size (int): The size of tensor parallel
tensor_parallel_process_group (ProcessGroup, optional): The process group for tensor parallelism; defaults to None, which means the global process group.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False
"""
tensor_parallel_size: int
tensor_parallel_process_group: ProcessGroup = None
enable_fused_normalization: bool = False

# TODO: add support for tensor parallel
@@ -25,10 +28,5 @@ class ShardConfig:
# gather_output: bool = True

def __post_init__(self):
coordinator = DistCoordinator()

# ensure the parallel size can match the world size
world_size = coordinator.world_size
self.data_parallel_size = world_size // self.tensor_parallel_size
assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"
# get the parallel size
self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
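With this change, `ShardConfig` no longer takes a `tensor_parallel_size` argument: `__post_init__` derives it from the supplied process group via `dist.get_world_size`. A minimal sketch of the resulting behaviour (the two-rank group is purely illustrative and assumes the distributed backend is already initialised):

```python
import torch.distributed as dist

from colossalai.shardformer import ShardConfig

# default: tensor_parallel_process_group=None means the global process group,
# so the derived tensor parallel size equals the world size
config = ShardConfig(enable_fused_normalization=True)
assert config.tensor_parallel_size == dist.get_world_size()

# explicit group: the derived size is the size of that group
tp_group = dist.new_group(ranks=[0, 1])    # illustrative 2-rank group
config = ShardConfig(tensor_parallel_process_group=tp_group)
# on ranks 0 and 1, config.tensor_parallel_size == 2
```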
11 changes: 3 additions & 8 deletions colossalai/shardformer/shard/sharder.py
@@ -22,16 +22,10 @@ class ModelSharder(object):
shard_config: The setting of distributed model
"""

def __init__(
self,
model: nn.Module,
policy: Policy,
shard_config: ShardConfig = None, # TODO
pg_manager: ProcessGroupManager = None) -> None:
def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
self.model = model
self.policy = get_autopolicy(self.model) if policy is None else policy
self.shard_config = shard_config
self.pg_manager = pg_manager

def shard(self) -> None:
r"""
@@ -198,7 +192,8 @@ def _replace_sub_module(
continue

try:
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
replace_layer = target_module.from_native_module(native_sub_module,
self.shard_config.tensor_parallel_process_group,
**kwargs)
except Exception as e:
raise RuntimeError(
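Downstream of this change, each replacement layer's `from_native_module` now receives `shard_config.tensor_parallel_process_group` instead of a group looked up from the removed `ProcessGroupManager`. A hedged sketch of a custom replacement layer written against that call signature (the class below is hypothetical and not part of the PR; the forward pass is omitted):

```python
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup


class MyColumnParallelLinear(nn.Module):
    """Hypothetical layer, shown only to illustrate the new from_native_module signature."""

    def __init__(self, weight_shard: nn.Parameter, process_group: ProcessGroup):
        super().__init__()
        self.weight = weight_shard
        self.process_group = process_group

    @staticmethod
    def from_native_module(module: nn.Linear, process_group: ProcessGroup, **kwargs) -> "MyColumnParallelLinear":
        # shard the output dimension across the ranks of the tensor parallel group
        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)
        weight_shard = module.weight.chunk(tp_size, dim=0)[tp_rank].detach().clone()
        return MyColumnParallelLinear(nn.Parameter(weight_shard), process_group)
```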
25 changes: 2 additions & 23 deletions colossalai/shardformer/shard/shardformer.py
@@ -1,7 +1,6 @@
import torch.nn as nn
from torch.utils.data import Dataset

from colossalai.cluster import DistCoordinator, ProcessGroupManager
from colossalai.cluster import DistCoordinator

from ..policies.basepolicy import Policy
from .shard_config import ShardConfig
@@ -28,7 +27,6 @@ class ShardFormer:
tensor_parallel_mode='1d',
)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
model = shard_former.shard_model(org_model)
```
"""
@@ -41,19 +39,6 @@ def __init__(self, shard_config: ShardConfig):
"""
self.coordinator = DistCoordinator()
self.shard_config = shard_config
self.pg_manager = None

def init_distributed(self) -> ProcessGroupManager:
"""
Initialize the distributed process group according to the
"""
# create process group manager and 1d process group
# TODO: may need to support other parallel mode when the config has such as field
pg_manager = ProcessGroupManager()
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
self.pg_manager = pg_manager

return pg_manager

def shard_model(self, model: nn.Module, policy: Policy = None):
r"""
@@ -64,12 +49,6 @@ def shard_model(self, model: nn.Module, policy: Policy = None):
shard_config (`ShardConfig`): the config for the distributed setting
policy (`Policy`): the custom policy for sharding
"""
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager)
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
sharder.shard()
return model

def shard_dataset(self, dataset: Dataset):
"""
Shard dataset for DP
"""
pass
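With `init_distributed` gone, the public workflow collapses to config → `ShardFormer` → `shard_model`. A minimal usage sketch, assuming the script is launched with torchrun and that an auto-policy exists for the chosen model (the BERT model here is illustrative only):

```python
import colossalai
from transformers import BertConfig, BertModel

from colossalai.shardformer import ShardConfig, ShardFormer

# assumes launch via torchrun so that colossalai can read the distributed env vars
colossalai.launch_from_torch(config={})

org_model = BertModel(BertConfig()).cuda()

# no tensor_parallel_size and no init_distributed() call any more:
# the config defaults to the global process group and derives the size from it
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.shard_model(org_model)
```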
6 changes: 2 additions & 4 deletions tests/test_shardformer/test_model/_utils.py
@@ -3,17 +3,15 @@
from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(world_size, model_fn):
def build_model(model_fn):
# create new model
org_model = model_fn().cuda()

# shard model
shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
shard_config = ShardConfig(enable_fused_normalization=True)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model_copy).cuda()

return org_model, sharded_model


2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_bert.py
@@ -42,7 +42,7 @@ def check_bert(rank, world_size, port):

sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_bloom.py
@@ -42,7 +42,7 @@ def check_bloom(rank, world_size, port):

sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -43,7 +43,7 @@ def check_gpt2(rank, world_size, port):

sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_llama.py
@@ -50,7 +50,7 @@ def check_llama(rank, world_size, port):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_opt.py
@@ -54,7 +54,7 @@ def check_OPTModel(rank, world_size, port):

sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_t5.py
@@ -42,7 +42,7 @@ def check_t5(rank, world_size, port):
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

torch.cuda.empty_cache()
77 changes: 77 additions & 0 deletions tests/test_shardformer/test_with_torch_ddp.py
@@ -0,0 +1,77 @@
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo


def check_shardformer_with_ddp(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')

# create shardformer
# ranks: [0, 1, 2, 3]
# tp ranks = [0, 1], [2, 3]
# dp ranks = [0, 2], [1, 3]
dp_process_group_1 = dist.new_group([0, 2])
dp_process_group_2 = dist.new_group([1, 3])
tp_process_group_1 = dist.new_group([0, 1])
tp_process_group_2 = dist.new_group([2, 3])

coordinator = DistCoordinator()

if coordinator.rank in [0, 1]:
tp_process_group = tp_process_group_1
else:
tp_process_group = tp_process_group_2

if coordinator.rank in [0, 2]:
dp_process_group = dp_process_group_1
else:
dp_process_group = dp_process_group_2

shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
shardformer = ShardFormer(shard_config=shard_config)

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create and shard model
model = model_fn().cuda()
sharded_model = shardformer.shard_model(model)

# add ddp
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)

# prepare input
data = data_gen_fn()
data = {k: v.cuda() for k, v in data.items()}

# switch to train mode
sharded_ddp_model.train()

# run forward
output = sharded_ddp_model(**data)
loss = loss_fn(output)

# backward
loss.backward()
torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2():
spawn(check_shardformer_with_ddp, 4)


if __name__ == "__main__":
test_gpt2()
test_gpt2()
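The group construction in this test is hard-coded for four ranks with a tensor parallel size of 2. As an illustrative generalisation (the helper below is not part of the PR), the same tensor/data parallel decomposition can be built for any world size divisible by the tensor parallel size:

```python
import torch.distributed as dist


def build_tp_dp_groups(tp_size: int):
    """Illustrative helper: build (tp_group, dp_group) for the calling rank."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % tp_size == 0, "world size must be divisible by tp_size"
    dp_size = world_size // tp_size

    tp_group, dp_group = None, None
    # consecutive blocks of tp_size ranks form the tensor parallel groups,
    # e.g. [0, 1] and [2, 3] for world_size=4, tp_size=2
    for i in range(dp_size):
        ranks = list(range(i * tp_size, (i + 1) * tp_size))
        group = dist.new_group(ranks)    # new_group must be called on every rank
        if rank in ranks:
            tp_group = group
    # ranks sharing a position inside their tp block form the data parallel groups,
    # e.g. [0, 2] and [1, 3] for world_size=4, tp_size=2
    for j in range(tp_size):
        ranks = list(range(j, world_size, tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_group = group
    return tp_group, dp_group
```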