@@ -3,7 +3,7 @@

import torch.nn as nn

from .basepolicy import Policy
from .base_policy import Policy

__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]

@@ -2,9 +2,13 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Type, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import Module

from colossalai.pipeline.stage_manager import PipelineStageManager

from ..shard.shard_config import ShardConfig

@@ -71,9 +75,8 @@ class Policy(ABC):
"""

def __init__(self) -> None:
self.shard_config = None
self.model = None
self.shard_config = None
self.shard_config: Optional[ShardConfig] = None
self.model: Optional[Module] = None

def set_model(self, model: nn.Module) -> None:
r"""
@@ -94,6 +97,12 @@ def set_shard_config(self, shard_config: ShardConfig) -> None:
self.shard_config = shard_config
self.config_sanity_check()

@property
def pipeline_stage_manager(self) -> Optional[PipelineStageManager]:
if self.shard_config is not None:
return self.shard_config.pipeline_stage_manager
return None

@abstractmethod
def config_sanity_check(self):
"""
@@ -151,3 +160,19 @@ def append_or_create_submodule_replacement(
policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)

return policy

def get_held_layers(self) -> List[Module]:
"""Get layers that should be held in current stage. This method should be implemented by subclass.

Returns:
List[Module]: List of layers that should be hold in current stage
"""
raise NotImplementedError

def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""Get parameters that should be shared across stages. This method should be implemented by subclass.

Returns:
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
"""
return []
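For orientation, here is a minimal, hypothetical sketch (not part of this PR) of how a concrete policy might fill in the two new hooks for a toy causal LM whose module exposes `embed_tokens`, `layers` and `lm_head`. The stage-manager attributes used below (`stage`, `num_stages`, `is_first_stage()`, `is_last_stage()`) and the stubbed abstract hooks are assumptions, not taken from this diff:

```python
from typing import Dict, List

from torch import Tensor
from torch.nn import Module

from colossalai.shardformer.policies.base_policy import Policy


class ToyCausalLMPolicy(Policy):
    # Minimal stubs for the other abstract hooks of Policy (assumed API).
    def config_sanity_check(self):
        pass

    def preprocess(self) -> Module:
        return self.model

    def module_policy(self):
        return {}

    def postprocess(self) -> Module:
        return self.model

    def get_held_layers(self) -> List[Module]:
        # Assumes a pipeline stage manager has been configured.
        stage_manager = self.pipeline_stage_manager
        blocks = self.model.layers
        # Evenly split the transformer blocks across pipeline stages;
        # the last stage picks up any remainder.
        per_stage = len(blocks) // stage_manager.num_stages
        start = stage_manager.stage * per_stage
        end = len(blocks) if stage_manager.is_last_stage() else start + per_stage
        held: List[Module] = list(blocks[start:end])
        if stage_manager.is_first_stage():
            held.append(self.model.embed_tokens)
        if stage_manager.is_last_stage():
            held.append(self.model.lm_head)
        return held

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        # Tie the input embedding (first stage) to the LM head (last stage).
        last_stage = self.pipeline_stage_manager.num_stages - 1
        return [{0: self.model.embed_tokens.weight,
                 last_stage: self.model.lm_head.weight}]
```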
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/bert.py
@@ -3,7 +3,7 @@
import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/bloom.py
@@ -4,7 +4,7 @@

from .._utils import getattr_, setattr_
from ..modeling.bloom import build_bloom_alibi_tensor_fn
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class BloomPolicy(Policy):
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/gpt2.py
@@ -3,7 +3,7 @@
import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/llama.py
@@ -4,7 +4,7 @@

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']

2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/opt.py
@@ -1,7 +1,7 @@
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
4 changes: 2 additions & 2 deletions colossalai/shardformer/policies/t5.py
@@ -6,10 +6,10 @@
Linear1D_Row,
VocabParallelEmbedding1D,
)
from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]

2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/vit.py
@@ -4,7 +4,7 @@

from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['ViTPolicy']

9 changes: 7 additions & 2 deletions colossalai/shardformer/shard/shard_config.py
@@ -1,8 +1,11 @@
from dataclasses import dataclass
from typing import Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.pipeline.stage_manager import PipelineStageManager

__all__ = ['ShardConfig']


@@ -12,12 +15,14 @@ class ShardConfig:
The config for sharding the huggingface model

Args:
tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
tensor_parallel_process_group (Optional[ProcessGroup]): The process group for tensor parallelism, defaults to None, which is the global process group.
pipeline_stage_manager (Optional[PipelineStageManager]): The pipeline stage manager, defaults to None, which means no pipeline.
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
"""
tensor_parallel_process_group: ProcessGroup = None
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
enable_tensor_parallelism: bool = True
enable_fused_normalization: bool = False
enable_all_optimization: bool = False
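A hedged usage sketch (not from this PR) of the extended config: the new `pipeline_stage_manager` field defaults to `None`, which keeps pipeline parallelism off, while passing a `PipelineStageManager` built elsewhere switches it on. Constructing a `ShardConfig` is assumed to happen inside an already-initialized distributed run:

```python
from colossalai.shardformer.shard.shard_config import ShardConfig

# Default: tensor parallelism over the global process group, no pipeline.
config = ShardConfig(enable_tensor_parallelism=True,
                     enable_fused_normalization=False)
assert config.pipeline_stage_manager is None

# With pipeline parallelism, assuming `stage_manager` (a PipelineStageManager)
# was created beforehand from the job's process-group mesh:
# config_pp = ShardConfig(pipeline_stage_manager=stage_manager,
#                         enable_tensor_parallelism=True)
```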
33 changes: 30 additions & 3 deletions colossalai/shardformer/shard/sharder.py
@@ -1,11 +1,15 @@
from typing import Any, Callable, Dict, List, Union

import torch.nn as nn
from torch import Tensor

from colossalai.lazy import LazyTensor

from .._utils import getattr_, setattr_
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy, SubModuleReplacementDescription
from ..policies.auto_policy import get_autopolicy
from ..policies.base_policy import Policy, SubModuleReplacementDescription
from .shard_config import ShardConfig
from .utils import set_tensors_to_none

__all__ = ['ModelSharder', 'shard_model']

@@ -25,15 +29,18 @@ def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig =
self.policy = get_autopolicy(self.model) if policy is None else policy
self.shard_config = shard_config

def shard(self) -> None:
def shard(self) -> List[Dict[int, Tensor]]:
r"""
Shard the model according to the policy
"""
self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self._preprocess()
self._release_unheld_layers()
self._replace_module()
self._materialize()
self._postprocess()
return self.policy.get_shared_params()

def _preprocess(self) -> None:
self.model = self.policy.preprocess()
@@ -172,3 +179,23 @@ def _replace_sub_module(
)

setattr_(org_layer, suffix, replace_layer)

def _release_unheld_layers(self) -> None:
r"""
Release the unheld layers in the model
"""
if self.shard_config and self.shard_config.pipeline_stage_manager:
held_layers = self.policy.get_held_layers()
set_tensors_to_none(self.model, exclude=set(held_layers))

def _materialize(self) -> None:
r"""
Materialize the model if lazy initialization is used
"""
for p in self.model.parameters():
if isinstance(p, LazyTensor):
p.materialize()

for b in self.model.buffers():
if isinstance(b, LazyTensor):
b.materialize()
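To make the new control flow concrete, a hedged sketch (not part of this PR) of what a caller sees after `shard()` when a pipeline stage manager is configured; `my_model` and `my_shard_config` are assumed to exist and to describe a model that the auto-policy supports:

```python
from colossalai.shardformer.shard.sharder import ModelSharder

# `my_model` and `my_shard_config` are hypothetical placeholders.
sharder = ModelSharder(model=my_model, policy=None, shard_config=my_shard_config)
shared_params = sharder.shard()

# After shard():
#   * only the layers returned by policy.get_held_layers() keep their tensors;
#     everything else is released via set_tensors_to_none (this only happens
#     when shard_config.pipeline_stage_manager is set),
#   * any LazyTensor parameters/buffers that remain are materialized,
#   * shared_params is the List[Dict[int, Tensor]] from policy.get_shared_params().
```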
15 changes: 10 additions & 5 deletions colossalai/shardformer/shard/shardformer.py
@@ -1,8 +1,11 @@
from typing import Dict, List, Tuple

import torch.nn as nn
from torch import Tensor

from colossalai.cluster import DistCoordinator

from ..policies.basepolicy import Policy
from ..policies.base_policy import Policy
from .shard_config import ShardConfig
from .sharder import ModelSharder

@@ -24,23 +27,25 @@ class ShardFormer:
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)
model = shard_former.optimize(org_model)
model, shared_params = shard_former.optimize(org_model)
```
"""

def __init__(self, shard_config: ShardConfig):
self.coordinator = DistCoordinator()
self.shard_config = shard_config

def optimize(self, model: nn.Module, policy: Policy = None):
def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
r"""
This method will optimize the model based on the given policy.

Args:
model (`torch.nn.Module`): the original huggingface model
shard_config (`ShardConfig`): the config describing how the model is distributed
policy (`Policy`): the custom policy for sharding

Returns: the sharded model and the shared parameters
"""
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
sharder.shard()
return model
shared_params = sharder.shard()
return model, shared_params
19 changes: 19 additions & 0 deletions colossalai/shardformer/shard/utils.py
@@ -0,0 +1,19 @@
from typing import Set

import torch.nn as nn


def set_tensors_to_none(model: nn.Module, exclude: Set[nn.Module] = set()) -> None:
"""Set all parameters and buffers of model to None

Args:
model (nn.Module): The model to set
"""
if model in exclude:
return
for child in model.children():
set_tensors_to_none(child, exclude=exclude)
for n, p in model.named_parameters(recurse=False):
setattr(model, n, None)
for n, buf in model.named_buffers(recurse=False):
setattr(model, n, None)
4 changes: 2 additions & 2 deletions tests/test_shardformer/test_model/_utils.py
@@ -12,8 +12,8 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
enable_tensor_parallelism=enable_tensor_parallelism)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(model_copy).cuda()
return org_model, sharded_model
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model, sharded_model.cuda()


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
27 changes: 27 additions & 0 deletions tests/test_shardformer/test_shard_utils.py
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn

from colossalai.shardformer.shard.utils import set_tensors_to_none


class Net(nn.Module):

def __init__(self) -> None:
super().__init__()
self.layers = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
self.out = nn.Linear(3, 1)


def test_release_layer():
orig_cuda_allocated = torch.cuda.memory_allocated()
model = Net().cuda()
set_tensors_to_none(model, exclude={model.layers[0]})
assert model.layers[1].weight is None
assert model.layers[1].bias is None
assert model.out.weight is None
assert model.out.bias is None
set_tensors_to_none(model)
assert model.layers[0].weight is None
assert model.layers[0].bias is None
assert len(list(model.parameters())) == 0
assert torch.cuda.memory_allocated() == orig_cuda_allocated
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_with_torch_ddp.py
@@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port):
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create and shard model
model = model_fn().cuda()
sharded_model = shardformer.optimize(model)
sharded_model, _ = shardformer.optimize(model)

# add ddp
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)