From 321d50ca71b44286b440712f14b8f48296351124 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Thu, 11 May 2023 15:46:13 +0800 Subject: [PATCH 01/12] init shardformer code structure --- colossalai/shardformer/__init__.py | 0 colossalai/shardformer/policies/__init__.py | 0 colossalai/shardformer/policies/autopolicy.py | 42 ++++++ colossalai/shardformer/policies/basepolicy.py | 141 ++++++++++++++++++ colossalai/shardformer/policies/bert.py | 112 ++++++++++++++ colossalai/shardformer/shard/__init__.py | 0 colossalai/shardformer/shard/sharder.py | 59 ++++++++ colossalai/shardformer/shard/shardmodel.py | 72 +++++++++ colossalai/shardformer/shard/slicer.py | 22 +++ colossalai/shardformer/shardmodel/__init__.py | 0 .../shardformer/shardmodel/modeling_bert.py | 62 ++++++++ colossalai/shardformer/utils/utils.py | 0 12 files changed, 510 insertions(+) create mode 100644 colossalai/shardformer/__init__.py create mode 100644 colossalai/shardformer/policies/__init__.py create mode 100644 colossalai/shardformer/policies/autopolicy.py create mode 100644 colossalai/shardformer/policies/basepolicy.py create mode 100644 colossalai/shardformer/policies/bert.py create mode 100644 colossalai/shardformer/shard/__init__.py create mode 100644 colossalai/shardformer/shard/sharder.py create mode 100644 colossalai/shardformer/shard/shardmodel.py create mode 100644 colossalai/shardformer/shard/slicer.py create mode 100644 colossalai/shardformer/shardmodel/__init__.py create mode 100644 colossalai/shardformer/shardmodel/modeling_bert.py create mode 100644 colossalai/shardformer/utils/utils.py diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/policies/__init__.py b/colossalai/shardformer/policies/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py new file mode 100644 index 000000000000..225b083cdaac --- /dev/null +++ b/colossalai/shardformer/policies/autopolicy.py @@ -0,0 +1,42 @@ +import torch.nn as nn + +def build_policies(): + """ + Build the policies for the model + + Return: + The dict for the policies + """ + auto_policy_dict = {} + + from transformers.models.bert.modeling_bert import BertForMaskedLM + from bert import BertForMaskedLMPolicy + auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy + + from transformers.models.bert.modeling_bert import BertForSequenceClassification + from bert import BertForSequenceClassificationPolicy + auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy + + return auto_policy_dict + +def get_autopolicy(model:nn.Module): + """ + Return the auto policy for the model + + Args: + model: The model to be used + + Return: + The auto policy for the model + """ + print(model) + auto_policy_dict = build_policies() + policy = auto_policy_dict.get(model, None) + if policy is None: + raise NotImplementedError(f"Auto policy for {model.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}") + return policy + +# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining +# model = BertForPreTraining +# policy = get_autopolicy(model) +# print(policy) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py new file mode 100644 index 000000000000..dd06271a1afb --- 
/dev/null +++ b/colossalai/shardformer/policies/basepolicy.py @@ -0,0 +1,141 @@ +# part of code modified from https://github.com/tunib-ai/parallelformers + +import torch +import torch.nn as nn +import colossalai.nn as col_nn +from typing import Any, Dict, List, Type +from transformers import AutoConfig +from dataclasses import dataclass + +@dataclass +class Layer: + """ + The layer object for the policy + + Args: + weight: The weight name of the layer + bias: The bias name of the layer + replace_layer: The layer to replace the original layer + ignore: Whether to ignore this layer if it is not in the model + """ + weight: str + bias: str + replace_layer: Any + ignore: bool = False + +class Policy(): + """ + The base class for all the policies + """ + def __init__( + self, + inject_layer: nn.Module + ) -> None: + """ + Init the policy class + + Args: + inject_layer: Layer the policy will apply to + """ + self.inject_layer = inject_layer + + @staticmethod + def argument_policy(config, dist_setting: int) -> Dict: + """ + Return the argument and its value need to be modified + + Args: + config: The config of transformer model + dist_setting: The setting of distributed model + + Return: + Dict for the modify policy + + """ + return {} + + + @staticmethod + def inject_policy() -> Dict[nn.Module, nn.Module]: + """ + Return the dict for the inject model + + Return: + The injected model, key is the original model and value is the new shardmodel + """ + return {} + + @staticmethod + def attn_in() -> List: + """ + Attention qkv layer + + Returns: + List[Layer]: List of layer object, each layer is the new + """ + return [] + + @staticmethod + def attn_out() -> List: + """ + Attention output projection layer + + Returns: + List[Layer]: List of layer object + """ + return [] + + @staticmethod + def mlp_in() -> List: + """ + h -> 4h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return [] + + @staticmethod + def mlp_out() -> List: + """ + 4h -> h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return [] + + @staticmethod + def embedding()->List: + """ + Partially slice the embedding layer + vocab_size->vocab_size//gpu_nums + + Return: + List[Layer]: List of layer object + """ + return [] + + @staticmethod + def unembedding()->List: + """ + Partially slice the embedding layer + vocab_size->vocab_size//gpu_nums + + Return: + List[Layer]: List of layer object + """ + return [] + + + # @staticmethod + # def original_layer_class() -> Type[nn.Module]: + # """ + # Class to apply the policy to + # e.g. BertLayer, GPT2Block, BartEncoderLayer, ... 
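
Taken together, these hooks mean a concrete policy only has to list which parameter paths get sharded and which parallel layer replaces them. As a minimal sketch, a hypothetical subclass built on the Policy and Layer classes above might look like this (the class name and attribute paths are invented for illustration):

    import colossalai.nn as col_nn
    from typing import List

    class ToyPolicy(Policy):
        """Hypothetical policy that shards only the attention input projection."""

        @staticmethod
        def attn_in() -> List:
            return [
                Layer(
                    weight="attention.query.weight",   # dotted path inside the transformer block
                    bias="attention.query.bias",
                    replace_layer=col_nn.Linear,       # parallel layer the sharder swaps in
                ),
            ]

Every hook that is not overridden keeps the default empty list, so nothing else in the model is touched.
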
+ + # Returns: + # Type[nn.Module]: original layer class + # """ + # raise NotImplementedError + diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py new file mode 100644 index 000000000000..25fa75413ab1 --- /dev/null +++ b/colossalai/shardformer/policies/bert.py @@ -0,0 +1,112 @@ +from typing import Dict, List, Type + +import torch.nn as nn +from basepolicy import Policy, Layer +import colossalai.nn as col_nn + +class BertPolicy(Policy): + @staticmethod + def attn_in() -> List: + return [ + Layer( + weight="attention.self.query.weight", + bias="attention.self.query.bias", + replace_layer=col_nn.Linear, + ), + Layer( + weight="attention.self.key.weight", + bias="attention.self.key.bias", + replace_layer=col_nn.Linear, + ), + Layer( + weight="attention.self.value.weight", + bias="attention.self.value.bias", + replace_layer=col_nn.Linear, + ), + Layer( + weight="crossattention.self.query.weight", + bias="crossattention.self.query.bias", + replace_layer=col_nn.Linear, + ignore=True, + ), + Layer( + weight="crossattention.self.key.weight", + bias="crossattention.self.key.bias", + replace_layer=col_nn.Linear, + ignore=True, + ), + Layer( + weight="crossattention.self.value.weight", + bias="crossattention.self.value.bias", + replace_layer=col_nn.Linear, + ignore=True, + ), + + ] + + @staticmethod + def attn_out() -> List: + return [ + Layer( + weight="attention.output.dense.weight", + bias="attention.output.dense.bias", + replace_layer=col_nn.Linear, + ), + Layer( + weight="crossattention.output.dense.weight", + bias="crossattention.output.dense.bias", + replace=col_nn.Linear, + ignore=True, + ), + ] + + @staticmethod + def mlp_in() -> List: + return [ + Layer( + weight="intermediate.dense.weight", + bias="intermediate.dense.bias", + replace_layer=col_nn.Linear, + ), + ] + + @staticmethod + def mlp_out() -> List: + return [ + Layer( + weight="output.dense.weight", + bias="output.dense.bias", + replace_layer=col_nn.Linear, + ), + ] + + @staticmethod + def embedding() -> List: + return [ + + ] + + @staticmethod + def unembedding() -> List: + return [ + + ] + +from transformers import BertForMaskedLM +from colossalai.shardformer.shardmodel.modeling_bert import BertForMaskedLM_ +class BertForMaskedLMPolicy(BertPolicy): + @staticmethod + def inject_policy() -> Dict: + return {BertForMaskedLM: BertForMaskedLM_} + + + +class BertForSequenceClassificationPolicy(BertPolicy): + @staticmethod + def inject_policy() -> Dict: + return {} + + +# model = BertForMaskedLM.from_pretrained("bert-base-uncased") +# _ = BertForMaskedLMPolicy(model) +# print(isinstance(model,list(_.inject_policy().keys())[0])) \ No newline at end of file diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py new file mode 100644 index 000000000000..30c19ee2dd54 --- /dev/null +++ b/colossalai/shardformer/shard/sharder.py @@ -0,0 +1,59 @@ +import torch.nn as nn +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +from .shardmodel import ShardConfig +from policies.basepolicy import Policy, Layer +from policies.autopolicy import get_autopolicy +from .slicer import Slicer + +class ModelSharder(object): + """ + Shard the original huggingface model according to the policy + + Args: + policy: The policy to shard the model + model: The model to shard + dist_setting: The setting of distributed 
model + """ + def __init__( + self, + model: nn.Module, + policy: Policy, + dist_config: ShardConfig, # TODO + ) -> None: + self.model = model + self.policy = get_autopolicy(self.model) if policy is None else policy + self.slicer = Slicer() + + def shard(self) -> None: + self.replace_model() + self.replace_layer(self.model) + + def replace_model(self) -> None: + """ + Replace the model to policy defined model + Mainly modify the forward and backward to fit distributed model + e.g.: + BertForMaskedLM -> BertForMaskedLM_ + """ + pass + + def replace_layer(self, layer: nn.Module) -> None: + """ + Replace the layer according to the policy + + Args: + layer: The layer to shard + """ + pass + + def shard_layer(self, policy: Policy) -> nn.Module: + """ + Shard the layer's weight and bias according to the policy + + Args: + policy + + Returns: + The sharded layer: nn.Module + """ + pass \ No newline at end of file diff --git a/colossalai/shardformer/shard/shardmodel.py b/colossalai/shardformer/shard/shardmodel.py new file mode 100644 index 000000000000..4ccb3ecd35cf --- /dev/null +++ b/colossalai/shardformer/shard/shardmodel.py @@ -0,0 +1,72 @@ +import os +import torch +import torch.nn as nn +import transformers +import torch.distributed as dist +from dataclasses import dataclass +from contextlib import suppress +from ..policies.basepolicy import Policy +from .sharder import ModelSharder +from colossalai.tensor.d_tensor.layout import Layout + + +@dataclass +class ShardConfig: + """ + The config for sharding the huggingface model for test + """ + fp16: bool + num_gpus: int + rank: int + backend="nccl" + verbose: str = 'simple' + seed: int = None + require_grad: bool = False + master_addr: str = "127.0.0.1" + master_port: int = 29500 + + +class ShardModel(): + """ + The class for sharding the huggingface model, self.model is the sharded model + Just creat a new ShardModel object to shard huggingface model + + Args: + model: the origin huggingface model + dist_config: the config for distribute information + custom_policy: the custom policy for sharding + """ + def __init__( + self, + model: nn.Module, + dist_config: ShardConfig, # TODO + custom_policy: Policy = None, + ) -> None: + self.model = model + self.dist_config = dist_config + self.policy = custom_policy + # self.layout=, # TODO + + sharder=ModelSharder( + model=self.model, + policy=self.policy, + dist_config=self.dist_config, + ) + sharder.shard() + + + def set_environ(self) -> None: + os.environ["TOKENIZERS_PARALLELISM"] = "true" + os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU" + os.environ["MASTER_ADDR"] = str(self.dist_config.master_addr) + os.environ["MASTER_PORT"] = str(self.dist_config.master_port) + os.environ["WORLD_SIZE"] = str(self.dist_config.num_gpus) + os.environ["RANK"] = str(self.dist_config.rank) + os.environ["LOCAL_RANK"] = str(self.dist_config.rank) + if not dist.is_initialized(): + dist.init_process_group(backend=self.dist_config.backend) + + torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0"))) + + def back_to_org() -> None: + pass \ No newline at end of file diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py new file mode 100644 index 000000000000..0cf86493165e --- /dev/null +++ b/colossalai/shardformer/shard/slicer.py @@ -0,0 +1,22 @@ +import os +from typing import Dict, Tuple + +import torch +import torch.distributed as dist +from .shardmodel import ShardConfig + +class Slicer(): + def __init__(self) -> None: + pass + + def slice_tensor( + self, + tensor_in: torch.Tensor, + 
dim: int, + is_bias: bool, + dist_config: ShardConfig, # TODO + ) -> torch.Tensor: + """ + Slice tensor according to the config + """ + pass \ No newline at end of file diff --git a/colossalai/shardformer/shardmodel/__init__.py b/colossalai/shardformer/shardmodel/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/shardmodel/modeling_bert.py b/colossalai/shardformer/shardmodel/modeling_bert.py new file mode 100644 index 000000000000..1d2e2d1bfded --- /dev/null +++ b/colossalai/shardformer/shardmodel/modeling_bert.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss +from typing import Any, Dict, List, Type + + +from transformers import BertForMaskedLM +from transformers.models.bert.modeling_bert import MaskedLMOutput +class BertForMaskedLM_(BertForMaskedLM): + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + + # if input_ids is not None: + # masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size) + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py new file mode 100644 index 000000000000..e69de29bb2d1 From 501df1853e39cd6928ae75ed84373096cee3126f Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 12 May 2023 16:17:40 +0800 Subject: [PATCH 02/12] add implement of sharder (inject and replace) --- .../{shardmodel => model}/__init__.py | 0 .../{shardmodel => model}/modeling_bert.py | 0 colossalai/shardformer/policies/basepolicy.py | 19 ++-- colossalai/shardformer/policies/bert.py | 31 +++++- colossalai/shardformer/shard/sharder.py | 95 ++++++++++++++++--- colossalai/shardformer/shard/slicer.py | 17 +++- colossalai/shardformer/utils/__init__.py | 0 colossalai/shardformer/utils/utils.py | 19 ++++ 8 files changed, 155 insertions(+), 26 deletions(-) rename colossalai/shardformer/{shardmodel => model}/__init__.py (100%) rename colossalai/shardformer/{shardmodel => model}/modeling_bert.py (100%) create mode 100644 colossalai/shardformer/utils/__init__.py diff --git a/colossalai/shardformer/shardmodel/__init__.py b/colossalai/shardformer/model/__init__.py similarity 
index 100% rename from colossalai/shardformer/shardmodel/__init__.py rename to colossalai/shardformer/model/__init__.py diff --git a/colossalai/shardformer/shardmodel/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py similarity index 100% rename from colossalai/shardformer/shardmodel/modeling_bert.py rename to colossalai/shardformer/model/modeling_bert.py diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index dd06271a1afb..85fb9b38d52b 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn import colossalai.nn as col_nn -from typing import Any, Dict, List, Type +from typing import Any, Dict, List, Type, Tuple from transformers import AutoConfig from dataclasses import dataclass @@ -29,7 +29,7 @@ class Policy(): """ def __init__( self, - inject_layer: nn.Module + replace_layer: nn.Module ) -> None: """ Init the policy class @@ -37,10 +37,10 @@ def __init__( Args: inject_layer: Layer the policy will apply to """ - self.inject_layer = inject_layer + self.replace_layer = replace_layer @staticmethod - def argument_policy(config, dist_setting: int) -> Dict: + def argument_policy(config, dist_setting: int) -> Dict[nn.Module, Dict]: """ Return the argument and its value need to be modified @@ -49,21 +49,26 @@ def argument_policy(config, dist_setting: int) -> Dict: dist_setting: The setting of distributed model Return: - Dict for the modify policy + Dict for the modify policy, + { + origin_layer1 (nn.Module): {argument1: value1, argument2: value2 ...}, + origin_layer2 (nn.Module): {argument1: value1, argument2: value2 ...}, + ... + } """ return {} @staticmethod - def inject_policy() -> Dict[nn.Module, nn.Module]: + def inject_policy() -> Tuple[nn.Module, nn.Module]: """ Return the dict for the inject model Return: The injected model, key is the original model and value is the new shardmodel """ - return {} + return () @staticmethod def attn_in() -> List: diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 25fa75413ab1..240b44fd5161 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,10 +1,31 @@ -from typing import Dict, List, Type +from typing import Dict, List, Tuple, Type import torch.nn as nn -from basepolicy import Policy, Layer +from .basepolicy import Policy, Layer import colossalai.nn as col_nn +from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings + class BertPolicy(Policy): + @staticmethod + def argument_policy(config, world_size: int) -> Dict[nn.Module,Dict]: + return { + BertLayer: { + # 1. shard hidden size + "attention.self.all_head_size": config.hidden_size // world_size, + "crossattention.self.all_head_size": config.hidden_size // world_size, + # 2. shard number of heads + "attention.self.num_attention_heads": config.num_attention_heads // world_size, + "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, + }, + BertEmbeddings: { + # 1. shard vocab size + "word_embeddings.num_embeddings": config.vocab_size // world_size, + # 2. 
add the size of the sliced embedding layer excluding the last slice + "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size, + } + } + @staticmethod def attn_in() -> List: return [ @@ -93,11 +114,11 @@ def unembedding() -> List: ] from transformers import BertForMaskedLM -from colossalai.shardformer.shardmodel.modeling_bert import BertForMaskedLM_ +from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ class BertForMaskedLMPolicy(BertPolicy): @staticmethod - def inject_policy() -> Dict: - return {BertForMaskedLM: BertForMaskedLM_} + def inject_policy() -> Tuple[nn.Module, nn.Module]: + return (BertForMaskedLM, BertForMaskedLM_) diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 30c19ee2dd54..85af0ea3ae0e 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,9 +1,27 @@ import torch.nn as nn from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union -from .shardmodel import ShardConfig -from policies.basepolicy import Policy, Layer -from policies.autopolicy import get_autopolicy +# from colossalai.shardformer.shard.shardmodel import ShardConfig +from dataclasses import dataclass +from ..policies.basepolicy import Policy, Layer +from ..policies.autopolicy import get_autopolicy from .slicer import Slicer +from ..utils.utils import hasattr_, setattr_ +import colossalai.nn as col_nn + +@dataclass +class ShardConfig: + """ + The config for sharding the huggingface model for test + """ + fp16: bool + num_gpus: int + rank: int + backend="nccl" + verbose: str = 'simple' + seed: int = None + require_grad: bool = False + master_addr: str = "127.0.0.1" + master_port: int = 29500 class ModelSharder(object): """ @@ -18,35 +36,64 @@ def __init__( self, model: nn.Module, policy: Policy, - dist_config: ShardConfig, # TODO + model_config, + dist_config: ShardConfig = None, # TODO ) -> None: self.model = model self.policy = get_autopolicy(self.model) if policy is None else policy self.slicer = Slicer() + self.dist_config = dist_config + self.model_config = model_config def shard(self) -> None: - self.replace_model() - self.replace_layer(self.model) + self.inject_model(self.model, self.policy) + self.replace_layer(self.model, self.policy) - def replace_model(self) -> None: + def inject_model( + self, + model: nn.Module, + policy_cls: Policy + ) -> None: """ Replace the model to policy defined model Mainly modify the forward and backward to fit distributed model e.g.: - BertForMaskedLM -> BertForMaskedLM_ + BertForMaskedLM.forward -> BertForMaskedLM_.forward """ - pass + inject_methods = ["forward"] + inject_policy = policy_cls.inject_policy() + print(inject_policy) + org_model_cls = inject_policy[0] + shard_model_cls = inject_policy[1] + if model.__class__ == org_model_cls: + for inject_method in inject_methods: + if hasattr(model, inject_method): + setattr( + model, + inject_method, + getattr(shard_model_cls,inject_method), + ) + else: + raise NotImplementedError(f"{model.__class__} is not implemented so far") - def replace_layer(self, layer: nn.Module) -> None: + def replace_layer( + self, + model: nn.Module, + policy_cls: Policy + ) -> None: """ Replace the layer according to the policy Args: layer: The layer to shard """ - pass + argument_policies = policy_cls.argument_policy(self.model_config, 2) + for argument_policy in argument_policies.items(): + origin_layer_cls = argument_policy[0] + attr_dict = argument_policy[1] + 
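# Each item pairs an original layer class with its attribute overrides;
        # for the BERT policy above one item is roughly
        #   (BertLayer, {"attention.self.all_head_size": hidden_size // world_size, ...}),
        # and reverse_replace_layer below walks the module tree to apply it.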
self.reverse_replace_layer(model, origin_layer_cls, attr_dict, policy_cls) - def shard_layer(self, policy: Policy) -> nn.Module: + def shard_layer(self, policy_obj: Policy) -> nn.Module: """ Shard the layer's weight and bias according to the policy @@ -56,4 +103,26 @@ def shard_layer(self, policy: Policy) -> nn.Module: Returns: The sharded layer: nn.Module """ - pass \ No newline at end of file + return None + pass + + def reverse_replace_layer( + self, + layer: nn.Module, + origin_cls: nn.Module, + attr_dict: Dict, + policy_cls: Policy, + ) -> None: + """ + Reverse the replace layer operation + """ + for name, child in layer.named_children(): + if child.__class__ == origin_cls: + policy_obj = policy_cls(replace_layer=child) + + for k, v in attr_dict.items(): + setattr_(policy_obj, f"replace_layer.{k}", v, ingore=True) + setattr_(layer, name, self.shard_layer(policy_obj)) + + self.reverse_replace_layer(child, origin_cls, attr_dict, policy_cls) + return layer \ No newline at end of file diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 0cf86493165e..4f0adb14dd7e 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -3,7 +3,22 @@ import torch import torch.distributed as dist -from .shardmodel import ShardConfig +from dataclasses import dataclass + +@dataclass +class ShardConfig: + """ + The config for sharding the huggingface model for test + """ + fp16: bool + num_gpus: int + rank: int + backend="nccl" + verbose: str = 'simple' + seed: int = None + require_grad: bool = False + master_addr: str = "127.0.0.1" + master_port: int = 29500 class Slicer(): def __init__(self) -> None: diff --git a/colossalai/shardformer/utils/__init__.py b/colossalai/shardformer/utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py index e69de29bb2d1..8af632fc94a2 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/utils/utils.py @@ -0,0 +1,19 @@ +def hasattr_(obj, attr: str): + attrs = attr.split('.') + for a in attrs: + try: + obj = getattr(obj, a) + except AttributeError: + return False + return True + +def setattr_(obj, attr: str, value, ingore: bool=False): + attrs = attr.split('.') + for a in attrs[:-1]: + try: + obj = getattr(obj, a) + except AttributeError: + if ingore: + return + raise AttributeError(f"Object {obj} has no attribute {a}") + setattr(obj, attrs[-1], value) From 2ce9fd5715c5c131576ecb227ba85c99be146952 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Tue, 16 May 2023 17:37:26 +0800 Subject: [PATCH 03/12] add implement of replace layer to colossal layer --- colossalai/shardformer/policies/autopolicy.py | 9 +- colossalai/shardformer/policies/bert.py | 12 +- colossalai/shardformer/shard/shardconfig.py | 17 ++ colossalai/shardformer/shard/sharder.py | 204 ++++++++++++++---- colossalai/shardformer/shard/shardmodel.py | 26 +-- colossalai/shardformer/shard/slicer.py | 39 ++-- colossalai/shardformer/utils/utils.py | 40 +++- 7 files changed, 254 insertions(+), 93 deletions(-) create mode 100644 colossalai/shardformer/shard/shardconfig.py diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 225b083cdaac..9142e0dae22e 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -10,11 +10,11 @@ def build_policies(): auto_policy_dict = {} from 
transformers.models.bert.modeling_bert import BertForMaskedLM - from bert import BertForMaskedLMPolicy + from .bert import BertForMaskedLMPolicy auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy from transformers.models.bert.modeling_bert import BertForSequenceClassification - from bert import BertForSequenceClassificationPolicy + from .bert import BertForSequenceClassificationPolicy auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy return auto_policy_dict @@ -29,11 +29,10 @@ def get_autopolicy(model:nn.Module): Return: The auto policy for the model """ - print(model) auto_policy_dict = build_policies() - policy = auto_policy_dict.get(model, None) + policy = auto_policy_dict.get(model.__class__, None) if policy is None: - raise NotImplementedError(f"Auto policy for {model.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}") + raise NotImplementedError(f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}") return policy # from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 240b44fd5161..127ef188503b 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -18,12 +18,12 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module,Dict]: "attention.self.num_attention_heads": config.num_attention_heads // world_size, "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, }, - BertEmbeddings: { - # 1. shard vocab size - "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice - "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size, - } + # BertEmbeddings: { + # # 1. shard vocab size + # "word_embeddings.num_embeddings": config.vocab_size // world_size, + # # 2. 
add the size of the sliced embedding layer excluding the last slice + # "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size, + # } } @staticmethod diff --git a/colossalai/shardformer/shard/shardconfig.py b/colossalai/shardformer/shard/shardconfig.py new file mode 100644 index 000000000000..f9ecde1d4337 --- /dev/null +++ b/colossalai/shardformer/shard/shardconfig.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass + + +@dataclass +class ShardConfig: + """ + The config for sharding the huggingface model for test + """ + fp16: bool + num_gpus: int + rank: int + backend="nccl" + verbose: str = 'simple' + seed: int = None + require_grad: bool = False + master_addr: str = "127.0.0.1" + master_port: int = 29500 \ No newline at end of file diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 85af0ea3ae0e..a9b7badcb0f6 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,27 +1,13 @@ import torch.nn as nn from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union -# from colossalai.shardformer.shard.shardmodel import ShardConfig +from .shardconfig import ShardConfig from dataclasses import dataclass from ..policies.basepolicy import Policy, Layer from ..policies.autopolicy import get_autopolicy from .slicer import Slicer -from ..utils.utils import hasattr_, setattr_ +from ..utils.utils import hasattr_, setattr_, getattr_ import colossalai.nn as col_nn -@dataclass -class ShardConfig: - """ - The config for sharding the huggingface model for test - """ - fp16: bool - num_gpus: int - rank: int - backend="nccl" - verbose: str = 'simple' - seed: int = None - require_grad: bool = False - master_addr: str = "127.0.0.1" - master_port: int = 29500 class ModelSharder(object): """ @@ -36,18 +22,19 @@ def __init__( self, model: nn.Module, policy: Policy, - model_config, - dist_config: ShardConfig = None, # TODO + shard_config: ShardConfig = None, # TODO ) -> None: self.model = model self.policy = get_autopolicy(self.model) if policy is None else policy - self.slicer = Slicer() - self.dist_config = dist_config - self.model_config = model_config + self.slicer = Slicer(shard_config) + self.shard_config = shard_config + self.model_config = self.model.config + def shard(self) -> None: self.inject_model(self.model, self.policy) self.replace_layer(self.model, self.policy) + def inject_model( self, @@ -62,9 +49,10 @@ def inject_model( """ inject_methods = ["forward"] inject_policy = policy_cls.inject_policy() - print(inject_policy) + org_model_cls = inject_policy[0] shard_model_cls = inject_policy[1] + if model.__class__ == org_model_cls: for inject_method in inject_methods: if hasattr(model, inject_method): @@ -76,13 +64,14 @@ def inject_model( else: raise NotImplementedError(f"{model.__class__} is not implemented so far") + def replace_layer( self, model: nn.Module, policy_cls: Policy ) -> None: """ - Replace the layer according to the policy + Replace the layer according to the policy, and replace the layer one by one Args: layer: The layer to shard @@ -93,18 +82,6 @@ def replace_layer( attr_dict = argument_policy[1] self.reverse_replace_layer(model, origin_layer_cls, attr_dict, policy_cls) - def shard_layer(self, policy_obj: Policy) -> nn.Module: - """ - Shard the layer's weight and bias according to the policy - - Args: - policy - - Returns: - The sharded layer: nn.Module - """ - return None - pass def reverse_replace_layer( self, @@ -115,14 +92,161 @@ def 
reverse_replace_layer( ) -> None: """ Reverse the replace layer operation + + Args: + layer: The object of layer to shard + origin_cls: The origin layer class + attr_dict: The attribute dict to modify + policy_cls: The policy class """ for name, child in layer.named_children(): if child.__class__ == origin_cls: - policy_obj = policy_cls(replace_layer=child) - + # replac_layer = child for k, v in attr_dict.items(): - setattr_(policy_obj, f"replace_layer.{k}", v, ingore=True) - setattr_(layer, name, self.shard_layer(policy_obj)) + setattr_(child, k, v, ignore=True) + # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__) + # setattr_(layer, name, self.shard_one_layer(child, policy_cls)) + self.shard_one_layer(child, policy_cls) + continue self.reverse_replace_layer(child, origin_cls, attr_dict, policy_cls) - return layer \ No newline at end of file + return layer + + + def shard_layer(self, policy_obj: Policy) -> nn.Module: + """ + Shard the layer's weight and bias according to the policy + + Args: + policy + + Returns: + The sharded layer: nn.Module + """ + attn_inw, attn_inb, attn_inw_attr, attn_inb_attr = self.preprocess( + policy.attn_in(), + policy, + ) + + attn_outw, attn_outb, attn_outw_attr, attn_outb_attr = self.preprocess( + policy.attn_out(), + policy, + ) + mlp_inw, mlp_inb, mlp_inw_attr, mlp_inb_attr = self.preprocess( + policy.mlp_in(), + policy, + ) + mlp_outw, mlp_outb, mlp_outw_attr, mlp_outb_attr = self.preprocess( + policy.mlp_out(), + policy, + ) + emd_w, emd_b, emd_w_attr, emd_b_attr = self.preprocess( + policy.embedding(), + policy, + ) + unemd_w, unemd_b, unemd_w_attr, unemd_b_attr = self.preprocess( + policy.unembedding(), + policy, + ) + + policy = self.set_parameters( + policy, + attn_inw, + attn_inb, + *self.slicer.column_slice( + (attn_inw, attn_inb), + (attn_inw_attr, attn_inb_attr), + ), + ) + + policy = self.set_parameters( + policy, + attn_outw, + attn_outb, + *self.slicer.row_slice( + (attn_outw, attn_outb), + (attn_outw_attr, attn_outb_attr), + ), + ) + + policy = self.set_parameters( + policy, + mlp_inw, + mlp_inb, + *self.slicer.column_slice( + (mlp_inw, mlp_inb), + (mlp_inw_attr, mlp_inb_attr), + ), + ) + + policy = self.set_parameters( + policy, + mlp_outw, + mlp_outb, + *self.slicer.row_slice( + (mlp_outw, mlp_outb), + (mlp_outw_attr, mlp_outb_attr), + ), + ) + + policy = self.set_parameters( + policy, + emd_w, + emd_b, + *self.slicer.column_slice( + (emd_w, emd_b), + (emd_w_attr, emd_b_attr), + ), + ) + + policy = self.set_parameters( + policy, + unemd_w, + unemd_b, + *self.slicer.column_slice( + (unemd_w, unemd_b), + (unemd_w_attr, unemd_b_attr), + ), + ) + + return policy_obj.replace_layer + + def shard_one_layer(self, org_layer: nn.Module, policy: Policy): + """ + Shard one layer + """ + # print(org_layer) + attn_in = policy.attn_in() + for layer in attn_in: + weight = None + bias = None + weight_attr = layer.weight + bias_attr = layer.bias + replace_layer_cls = layer.replace_layer + ignore = layer.ignore + + if weight_attr is not None: + if hasattr_(org_layer, weight_attr): + weight = getattr_(org_layer, weight_attr) + elif not ignore: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") + + if bias_attr is not None: + if hasattr_(org_layer, bias_attr): + bias = getattr_(org_layer, bias_attr) + elif not ignore: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") + + # dont have the attribute in policy + if weight is None and bias is None and ignore: 
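# e.g. an encoder-only BERT has no crossattention.* modules, which the
                # policy lists with ignore=True, so those missing entries are skipped here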
+ continue + + # set the sliced weight and bias to the new nn_col layer + assert weight is not None or bias is not None + weight, bias = self.slicer.slice_weight_bias(weight, bias, 0) + replece_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=True) + # print(replece_layer) + # replece_layer.weight = nn.Parameter(weight) + # replece_layer.bias = nn.Parameter(bias) + setattr_(org_layer, weight_attr[:weight_attr.rfind(".")], replece_layer, ignore=ignore) + diff --git a/colossalai/shardformer/shard/shardmodel.py b/colossalai/shardformer/shard/shardmodel.py index 4ccb3ecd35cf..ac2878f80171 100644 --- a/colossalai/shardformer/shard/shardmodel.py +++ b/colossalai/shardformer/shard/shardmodel.py @@ -5,27 +5,13 @@ import torch.distributed as dist from dataclasses import dataclass from contextlib import suppress + +from colossalai.tensor.d_tensor.layout import Layout from ..policies.basepolicy import Policy from .sharder import ModelSharder -from colossalai.tensor.d_tensor.layout import Layout +from .shardconfig import ShardConfig -@dataclass -class ShardConfig: - """ - The config for sharding the huggingface model for test - """ - fp16: bool - num_gpus: int - rank: int - backend="nccl" - verbose: str = 'simple' - seed: int = None - require_grad: bool = False - master_addr: str = "127.0.0.1" - master_port: int = 29500 - - class ShardModel(): """ The class for sharding the huggingface model, self.model is the sharded model @@ -39,18 +25,18 @@ class ShardModel(): def __init__( self, model: nn.Module, - dist_config: ShardConfig, # TODO + shard_config: ShardConfig = None, # TODO custom_policy: Policy = None, ) -> None: self.model = model - self.dist_config = dist_config + self.shard_config = shard_config self.policy = custom_policy # self.layout=, # TODO sharder=ModelSharder( model=self.model, policy=self.policy, - dist_config=self.dist_config, + shard_config=self.shard_config, ) sharder.shard() diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 4f0adb14dd7e..35e5d602b370 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -1,37 +1,38 @@ import os from typing import Dict, Tuple +from dataclasses import dataclass import torch import torch.distributed as dist -from dataclasses import dataclass -@dataclass -class ShardConfig: - """ - The config for sharding the huggingface model for test - """ - fp16: bool - num_gpus: int - rank: int - backend="nccl" - verbose: str = 'simple' - seed: int = None - require_grad: bool = False - master_addr: str = "127.0.0.1" - master_port: int = 29500 +from .shardconfig import ShardConfig + class Slicer(): - def __init__(self) -> None: - pass + def __init__( + self, + shardconfig: ShardConfig #TODO + ) -> None: + self.shardconfig = shardconfig + + def slice_weight_bias( + self, + weight: torch.Tensor, + bias: torch.Tensor, + dim: int, + ) -> Tuple[torch.Tensor,torch.Tensor]: + weight = self.slice_tensor(weight, dim, False) + bias = self.slice_tensor(bias, dim, True) + return weight, bias def slice_tensor( self, tensor_in: torch.Tensor, dim: int, is_bias: bool, - dist_config: ShardConfig, # TODO ) -> torch.Tensor: """ Slice tensor according to the config """ - pass \ No newline at end of file + tensor_in = tensor_in[:tensor_in.shape[0]//2] + return tensor_in \ No newline at end of file diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py index 8af632fc94a2..87fe714899f8 100644 --- a/colossalai/shardformer/utils/utils.py +++ 
b/colossalai/shardformer/utils/utils.py @@ -1,4 +1,11 @@ def hasattr_(obj, attr: str): + """ + Check whether the object has the multi sublevel attr + + Args: + obj: The object to check + attr: The multi level attr to check + """ attrs = attr.split('.') for a in attrs: try: @@ -7,13 +14,40 @@ def hasattr_(obj, attr: str): return False return True -def setattr_(obj, attr: str, value, ingore: bool=False): +def setattr_(obj, attr: str, value, ignore: bool=False): + """ + Set the object's multi sublevel attr to value, if ignore, ignore when it doesn't exist + + Args: + obj: The object to set + attr: The multi level attr to set + value: The value to set + ignore: Whether to ignore when the attr doesn't exist + """ + attrs = attr.split('.') for a in attrs[:-1]: try: obj = getattr(obj, a) except AttributeError: - if ingore: + if ignore: return - raise AttributeError(f"Object {obj} has no attribute {a}") + raise AttributeError(f"Object {obj} has no attribute {attr}") setattr(obj, attrs[-1], value) + +def getattr_(obj, attr: str): + """ + Get the object's multi sublevel attr + + Args: + obj: The object to set + attr: The multi level attr to set + """ + + attrs = attr.split('.') + for a in attrs: + try: + obj = getattr(obj, a) + except AttributeError: + raise AttributeError(f"Object {obj} has no attribute {attr}") + return obj \ No newline at end of file From 6c949c40fe534d74769df880e6fec67fff2bb904 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 17 May 2023 15:39:38 +0800 Subject: [PATCH 04/12] separate different layer policy, add some notion --- colossalai/shardformer/policies/basepolicy.py | 59 +++-- colossalai/shardformer/policies/bert.py | 60 +++-- colossalai/shardformer/shard/sharder.py | 232 +++++++----------- colossalai/shardformer/shard/slicer.py | 6 +- 4 files changed, 173 insertions(+), 184 deletions(-) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 85fb9b38d52b..a0837b1577a6 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -3,10 +3,15 @@ import torch import torch.nn as nn import colossalai.nn as col_nn -from typing import Any, Dict, List, Type, Tuple +from typing import Any, Dict, List, Type, Tuple, Callable from transformers import AutoConfig from dataclasses import dataclass +@dataclass +class Argument: + attr_dict : Dict[str, Any] + param_funcs : List[Callable] + @dataclass class Layer: """ @@ -18,41 +23,51 @@ class Layer: replace_layer: The layer to replace the original layer ignore: Whether to ignore this layer if it is not in the model """ - weight: str - bias: str - replace_layer: Any + weight: str = None + bias: str = None + replace_layer: Any = None ignore: bool = False class Policy(): """ The base class for all the policies """ - def __init__( - self, - replace_layer: nn.Module - ) -> None: - """ - Init the policy class - - Args: - inject_layer: Layer the policy will apply to - """ - self.replace_layer = replace_layer - @staticmethod - def argument_policy(config, dist_setting: int) -> Dict[nn.Module, Dict]: + def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]: """ - Return the argument and its value need to be modified + Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions Args: - config: The config of transformer model - dist_setting: The setting of distributed model + model_config: The config of transformer model + 
shard_setting: The config of distributed model Return: Dict for the modify policy, { - origin_layer1 (nn.Module): {argument1: value1, argument2: value2 ...}, - origin_layer2 (nn.Module): {argument1: value1, argument2: value2 ...}, + origin layer class1 (nn.Module): Argument( + attr_dict = { + argument1: value1, + argument2: value2, + ... + }, + param_funcs = [ + staticmethod1, + staticmethod2, + ... + ] + ), + origin layer class2 (nn.Module): Argument( + attr_dict = { + argument1: value1, + argument2: value2, + ... + }, + param_funcs = [ + staticmethod1, + staticmethod2, + ... + ] + ), ... } diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 127ef188503b..007379410109 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,29 +1,44 @@ -from typing import Dict, List, Tuple, Type - +from typing import Dict, List, Tuple, Type, Any, Callable import torch.nn as nn -from .basepolicy import Policy, Layer +from .basepolicy import Policy, Layer, Argument import colossalai.nn as col_nn from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings +from dataclasses import dataclass class BertPolicy(Policy): @staticmethod - def argument_policy(config, world_size: int) -> Dict[nn.Module,Dict]: + def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]: return { - BertLayer: { - # 1. shard hidden size - "attention.self.all_head_size": config.hidden_size // world_size, - "crossattention.self.all_head_size": config.hidden_size // world_size, - # 2. shard number of heads - "attention.self.num_attention_heads": config.num_attention_heads // world_size, - "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, - }, - # BertEmbeddings: { - # # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # # 2. add the size of the sliced embedding layer excluding the last slice - # "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size, - # } + BertLayer: Argument( + attr_dict = { + # 1. shard hidden size + "attention.self.all_head_size": config.hidden_size // world_size, + "crossattention.self.all_head_size": config.hidden_size // world_size, + # 2. shard number of heads + "attention.self.num_attention_heads": config.num_attention_heads // world_size, + "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, + + }, + param_funcs = [ + BertPolicy.attn_in, + BertPolicy.attn_out, + # BertPolicy.mlp_in, + # BertPolicy.mlp_out + ] + ), + BertEmbeddings: Argument( + attr_dict = { + # 1. shard vocab size + "word_embeddings.num_embeddings": config.vocab_size // world_size, + # 2. 
add the size of the sliced embedding layer excluding the last slice + "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size, + }, + param_funcs = [ + BertPolicy.embedding, + BertPolicy.unembedding, + ] + ) } @staticmethod @@ -71,12 +86,12 @@ def attn_out() -> List: Layer( weight="attention.output.dense.weight", bias="attention.output.dense.bias", - replace_layer=col_nn.Linear, + # replace_layer=col_nn.Linear, ), Layer( weight="crossattention.output.dense.weight", bias="crossattention.output.dense.bias", - replace=col_nn.Linear, + # replace_layer=col_nn.Linear, ignore=True, ), ] @@ -104,7 +119,10 @@ def mlp_out() -> List: @staticmethod def embedding() -> List: return [ - + Layer( + weight="word_embeddings.weight", + # replace=AllReduceEmbedding, + ) ] @staticmethod diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index a9b7badcb0f6..241c17373c5d 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,5 +1,6 @@ +import torch import torch.nn as nn -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union, Callable from .shardconfig import ShardConfig from dataclasses import dataclass from ..policies.basepolicy import Policy, Layer @@ -44,7 +45,8 @@ def inject_model( """ Replace the model to policy defined model Mainly modify the forward and backward to fit distributed model - e.g.: + + e.g. BertForMaskedLM.forward -> BertForMaskedLM_.forward """ inject_methods = ["forward"] @@ -78,17 +80,19 @@ def replace_layer( """ argument_policies = policy_cls.argument_policy(self.model_config, 2) for argument_policy in argument_policies.items(): + print(argument_policy) origin_layer_cls = argument_policy[0] - attr_dict = argument_policy[1] - self.reverse_replace_layer(model, origin_layer_cls, attr_dict, policy_cls) + attr_dict = argument_policy[1].attr_dict + param_funcs = argument_policy[1].param_funcs + self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) def reverse_replace_layer( self, layer: nn.Module, origin_cls: nn.Module, - attr_dict: Dict, - policy_cls: Policy, + attr_dict: Dict[str, Any], + param_funcs: List[Callable], ) -> None: """ Reverse the replace layer operation @@ -106,147 +110,97 @@ def reverse_replace_layer( setattr_(child, k, v, ignore=True) # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__) # setattr_(layer, name, self.shard_one_layer(child, policy_cls)) - self.shard_one_layer(child, policy_cls) + self.shard_one_layer(child, param_funcs) continue - self.reverse_replace_layer(child, origin_cls, attr_dict, policy_cls) + self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs) return layer - def shard_layer(self, policy_obj: Policy) -> nn.Module: + def shard_one_layer(self, org_layer: nn.Module, param_funcs: List[Callable]) -> None: """ - Shard the layer's weight and bias according to the policy - + Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict + Args: - policy - - Returns: - The sharded layer: nn.Module + org_layer: The origin layer object to shard + param_funcs: The function list to get shard information in policy class + """ - attn_inw, attn_inb, attn_inw_attr, attn_inb_attr = self.preprocess( - policy.attn_in(), - policy, - ) - - attn_outw, attn_outb, attn_outw_attr, attn_outb_attr = self.preprocess( - policy.attn_out(), - policy, - ) - 
mlp_inw, mlp_inb, mlp_inw_attr, mlp_inb_attr = self.preprocess( - policy.mlp_in(), - policy, - ) - mlp_outw, mlp_outb, mlp_outw_attr, mlp_outb_attr = self.preprocess( - policy.mlp_out(), - policy, - ) - emd_w, emd_b, emd_w_attr, emd_b_attr = self.preprocess( - policy.embedding(), - policy, - ) - unemd_w, unemd_b, unemd_w_attr, unemd_b_attr = self.preprocess( - policy.unembedding(), - policy, - ) - - policy = self.set_parameters( - policy, - attn_inw, - attn_inb, - *self.slicer.column_slice( - (attn_inw, attn_inb), - (attn_inw_attr, attn_inb_attr), - ), - ) - - policy = self.set_parameters( - policy, - attn_outw, - attn_outb, - *self.slicer.row_slice( - (attn_outw, attn_outb), - (attn_outw_attr, attn_outb_attr), - ), - ) - - policy = self.set_parameters( - policy, - mlp_inw, - mlp_inb, - *self.slicer.column_slice( - (mlp_inw, mlp_inb), - (mlp_inw_attr, mlp_inb_attr), - ), - ) - - policy = self.set_parameters( - policy, - mlp_outw, - mlp_outb, - *self.slicer.row_slice( - (mlp_outw, mlp_outb), - (mlp_outw_attr, mlp_outb_attr), - ), - ) - - policy = self.set_parameters( - policy, - emd_w, - emd_b, - *self.slicer.column_slice( - (emd_w, emd_b), - (emd_w_attr, emd_b_attr), - ), - ) - - policy = self.set_parameters( - policy, - unemd_w, - unemd_b, - *self.slicer.column_slice( - (unemd_w, unemd_b), - (unemd_w_attr, unemd_b_attr), - ), - ) - - return policy_obj.replace_layer - - def shard_one_layer(self, org_layer: nn.Module, policy: Policy): + # print(org_layer) + for func in param_funcs: + param_attrs = func() + for layer in param_attrs: + weight = None + bias = None + weight_attr = layer.weight + bias_attr = layer.bias + replace_layer_cls = layer.replace_layer + ignore = layer.ignore + + if weight_attr is not None: + if hasattr_(org_layer, weight_attr): + weight = getattr_(org_layer, weight_attr) + elif not ignore: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") + + if bias_attr is not None: + if hasattr_(org_layer, bias_attr): + bias = getattr_(org_layer, bias_attr) + elif not ignore: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") + + # dont have the attribute in policy, and ignore is true + if weight is None and bias is None and ignore: + continue + + # set the sliced weight and bias to the new nn_col layer + assert weight is not None or bias is not None + layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr) + + weight, bias = self.slicer.slice_weight_bias(weight, bias, 0) + + # create new object to replace the origin layer + # TODO: col_nn + if replace_layer_cls is not None: + replece_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=True) + # print(replece_layer) + # replece_layer.weight = nn.Parameter(weight) + # replece_layer.bias = nn.Parameter(bias) + setattr_(org_layer, layer_attr, replece_layer, ignore=ignore) + # do not replace the layer object, just replace the weight and bias + else: + self.set_param(org_layer, layer_attr, weight, bias) + + + def set_param(self, layer: Any, layer_attr: str, weight: torch.Tensor, bias: torch.Tensor = None) -> None: """ - Shard one layer + Reset the weight and bias of the layer object + + Args: + layer: The layer object + layer_attr: The attribute name of the layer + weight: The weight of the layer + bias: The bias of the layer """ - # print(org_layer) - attn_in = policy.attn_in() - for layer in attn_in: - weight = None - bias = None - weight_attr = layer.weight - bias_attr = layer.bias - replace_layer_cls = layer.replace_layer - 
ignore = layer.ignore - - if weight_attr is not None: - if hasattr_(org_layer, weight_attr): - weight = getattr_(org_layer, weight_attr) - elif not ignore: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") - - if bias_attr is not None: - if hasattr_(org_layer, bias_attr): - bias = getattr_(org_layer, bias_attr) - elif not ignore: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") - - # dont have the attribute in policy - if weight is None and bias is None and ignore: - continue + assert weight is not None or bias is not None + if weight is not None: + setattr_(layer, layer_attr+".weight", nn.Parameter(weight)) + self.set_layer_size(layer, layer_attr, weight.shape) + if bias is not None: + setattr_(layer, layer_attr+".bias", nn.Parameter(bias)) - # set the sliced weight and bias to the new nn_col layer - assert weight is not None or bias is not None - weight, bias = self.slicer.slice_weight_bias(weight, bias, 0) - replece_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=True) - # print(replece_layer) - # replece_layer.weight = nn.Parameter(weight) - # replece_layer.bias = nn.Parameter(bias) - setattr_(org_layer, weight_attr[:weight_attr.rfind(".")], replece_layer, ignore=ignore) - + + def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None: + """ + Set the layer attribute + + Args: + layer: The layer object + layer_attr: The attribute name of the layer + size: Torch.size + """ + attrs = ["out_features", "is_features"] + for i, attr in enumerate(attrs): + print(layer, f"{layer_attr}.{attr}") + if hasattr_(layer, f"{layer_attr}.{attr}"): + setattr_(layer, f"{layer_attr}.{attr}", size[i]) diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 35e5d602b370..375f0db2be75 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -21,8 +21,10 @@ def slice_weight_bias( bias: torch.Tensor, dim: int, ) -> Tuple[torch.Tensor,torch.Tensor]: - weight = self.slice_tensor(weight, dim, False) - bias = self.slice_tensor(bias, dim, True) + if weight is not None: + weight = self.slice_tensor(weight, dim, False) + if bias is not None: + bias = self.slice_tensor(bias, dim, True) return weight, bias def slice_tensor( From e5d3d6620fb28d5c3b80908068af6bb2ea92e961 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Wed, 17 May 2023 18:08:04 +0800 Subject: [PATCH 05/12] implement 1d and 2d slicer, can tell col or row --- colossalai/shardformer/policies/basepolicy.py | 17 ++++ colossalai/shardformer/policies/bert.py | 18 ++-- colossalai/shardformer/shard/shardconfig.py | 5 +- colossalai/shardformer/shard/sharder.py | 18 ++-- colossalai/shardformer/shard/slicer.py | 94 ++++++++++++++++++- 5 files changed, 126 insertions(+), 26 deletions(-) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index a0837b1577a6..561933afb074 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -28,6 +28,23 @@ class Layer: replace_layer: Any = None ignore: bool = False + +@dataclass +class Col_Layer(Layer): + """ + Class for col shard layer in MegatronLM + """ + pass + + +@dataclass +class Row_Layer(Layer): + """ + Class for col shard layer in MegatronLM + """ + pass + + class Policy(): """ The base class for all the policies diff --git a/colossalai/shardformer/policies/bert.py 
b/colossalai/shardformer/policies/bert.py index 007379410109..d84a735eb0fd 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,6 +1,6 @@ from typing import Dict, List, Tuple, Type, Any, Callable import torch.nn as nn -from .basepolicy import Policy, Layer, Argument +from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer import colossalai.nn as col_nn from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings from dataclasses import dataclass @@ -23,14 +23,14 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]: param_funcs = [ BertPolicy.attn_in, BertPolicy.attn_out, - # BertPolicy.mlp_in, - # BertPolicy.mlp_out + BertPolicy.mlp_in, + BertPolicy.mlp_out ] ), BertEmbeddings: Argument( attr_dict = { # 1. shard vocab size - "word_embeddings.num_embeddings": config.vocab_size // world_size, + # "word_embeddings.num_embeddings": config.vocab_size // world_size, # 2. add the size of the sliced embedding layer excluding the last slice "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size, }, @@ -44,7 +44,7 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]: @staticmethod def attn_in() -> List: return [ - Layer( + Col_Layer( weight="attention.self.query.weight", bias="attention.self.query.bias", replace_layer=col_nn.Linear, @@ -99,20 +99,20 @@ def attn_out() -> List: @staticmethod def mlp_in() -> List: return [ - Layer( + Col_Layer( weight="intermediate.dense.weight", bias="intermediate.dense.bias", - replace_layer=col_nn.Linear, + # replace_layer=col_nn.Linear, ), ] @staticmethod def mlp_out() -> List: return [ - Layer( + Row_Layer( weight="output.dense.weight", bias="output.dense.bias", - replace_layer=col_nn.Linear, + # replace_layer=col_nn.Linear, ), ] diff --git a/colossalai/shardformer/shard/shardconfig.py b/colossalai/shardformer/shard/shardconfig.py index f9ecde1d4337..be265ff0c8c1 100644 --- a/colossalai/shardformer/shard/shardconfig.py +++ b/colossalai/shardformer/shard/shardconfig.py @@ -6,9 +6,10 @@ class ShardConfig: """ The config for sharding the huggingface model for test """ - fp16: bool - num_gpus: int rank: int + fp16: bool = True + num_gpus: int = 2 + world_size: int = 2 backend="nccl" verbose: str = 'simple' seed: int = None diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 241c17373c5d..23aaf3b1c201 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -80,7 +80,6 @@ def replace_layer( """ argument_policies = policy_cls.argument_policy(self.model_config, 2) for argument_policy in argument_policies.items(): - print(argument_policy) origin_layer_cls = argument_policy[0] attr_dict = argument_policy[1].attr_dict param_funcs = argument_policy[1].param_funcs @@ -128,14 +127,14 @@ def shard_one_layer(self, org_layer: nn.Module, param_funcs: List[Callable]) -> """ # print(org_layer) for func in param_funcs: - param_attrs = func() - for layer in param_attrs: + policy_layers = func() + for policy_layer in policy_layers: weight = None bias = None - weight_attr = layer.weight - bias_attr = layer.bias - replace_layer_cls = layer.replace_layer - ignore = layer.ignore + weight_attr = policy_layer.weight + bias_attr = policy_layer.bias + replace_layer_cls = policy_layer.replace_layer + ignore = policy_layer.ignore if weight_attr is not None: if hasattr_(org_layer, weight_attr): @@ -157,7 +156,7 @@ def shard_one_layer(self, org_layer: 
nn.Module, param_funcs: List[Callable]) -> assert weight is not None or bias is not None layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr) - weight, bias = self.slicer.slice_weight_bias(weight, bias, 0) + weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__) # create new object to replace the origin layer # TODO: col_nn @@ -199,8 +198,7 @@ def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> layer_attr: The attribute name of the layer size: Torch.size """ - attrs = ["out_features", "is_features"] + attrs = ["out_features", "in_features"] for i, attr in enumerate(attrs): - print(layer, f"{layer_attr}.{attr}") if hasattr_(layer, f"{layer_attr}.{attr}"): setattr_(layer, f"{layer_attr}.{attr}", size[i]) diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 375f0db2be75..353891fe37bd 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -4,29 +4,46 @@ import torch import torch.distributed as dist - +from ..policies.basepolicy import Layer, Col_Layer, Row_Layer from .shardconfig import ShardConfig +dim_mapping = {Col_Layer: 1, Row_Layer: 0} + class Slicer(): + def __init__( self, shardconfig: ShardConfig #TODO ) -> None: self.shardconfig = shardconfig + def slice_weight_bias( self, weight: torch.Tensor, bias: torch.Tensor, - dim: int, + policy_layer_cls: Layer, ) -> Tuple[torch.Tensor,torch.Tensor]: + """ + Slice the weight and bias according to the shardconfig + + Args: + weight: The weight of the layer + bias: The bias of the layer + policy_layer_class: The class represent how to slice the tensor + """ + if policy_layer_cls == Layer: + return weight, bias if weight is not None: + assert policy_layer_cls in dim_mapping, f"The policy layer class {policy_layer_cls} is not supported" + dim = dim_mapping[policy_layer_cls] weight = self.slice_tensor(weight, dim, False) if bias is not None: - bias = self.slice_tensor(bias, dim, True) + bias = self.slice_tensor(bias, 1, True) return weight, bias + def slice_tensor( self, tensor_in: torch.Tensor, @@ -36,5 +53,72 @@ def slice_tensor( """ Slice tensor according to the config """ - tensor_in = tensor_in[:tensor_in.shape[0]//2] - return tensor_in \ No newline at end of file + if not is_bias: + return self.slice_2d(tensor_in, dim) + else: + return self.slice_1d(tensor_in) + + + def slice_2d( + self, + tensor: torch.Tensor, + dim: int, + ) -> torch.Tensor: + """ + Slice the 2D tensor + + Args: + tensor: The tensor to slice + """ + assert dim in [0,1], f"Only support 2D tensor, but got {dim}D tensor" + if dim == 0: + return self.slice_row(tensor) + elif dim == 1: + return self.slice_col(tensor) + + def slice_1d( + self, + tensor: torch.Tensor, + dim: int = None, + ) -> torch.Tensor: + """ + Slice the 1D tensor + + Args: + tensor: The tensor to slice + """ + delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size + down_idx = (self.shardconfig.rank - 1) * delta + up_idx = down_idx + delta + return tensor[down_idx:up_idx] + + def slice_col( + self, + tensor: torch.Tensor, + ) -> torch.Tensor: + """ + Slice the tensor in column + + Args: + tensor: The tensor to slice + """ + delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size + down_idx = (self.shardconfig.rank - 1) * delta + up_idx = down_idx + delta + return tensor[down_idx:up_idx,:] + + + def slice_row( + self, + tensor: torch.Tensor, + ) -> torch.Tensor: + """ + Slice the 
tensor in row
+
+        Args:
+            tensor: The tensor to slice
+        """
+        delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
+        down_idx = (self.shardconfig.rank - 1) * delta
+        up_idx = down_idx + delta
+        return tensor[:,down_idx:up_idx]
\ No newline at end of file

From a06810cade21117d99695c042e34b165d4ca08ef Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Thu, 18 May 2023 18:04:10 +0800
Subject: [PATCH 06/12] fix bugs when slicing and injecting the model

---
 colossalai/shardformer/model/modeling_bert.py |  1 +
 colossalai/shardformer/policies/basepolicy.py | 43 ++++++++++---------
 colossalai/shardformer/shard/sharder.py       | 41 +++++++++++-------
 colossalai/shardformer/shard/shardmodel.py    |  2 +-
 colossalai/shardformer/shard/slicer.py        |  4 +-
 5 files changed, 52 insertions(+), 39 deletions(-)

diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py
index 1d2e2d1bfded..87ed8ac308a5 100644
--- a/colossalai/shardformer/model/modeling_bert.py
+++ b/colossalai/shardformer/model/modeling_bert.py
@@ -23,6 +23,7 @@ def forward(
         return_dict=None,
         **kwargs,
     ):
+        print("[Inject OK] Injected forward method")
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = self.bert(

diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py
index 561933afb074..4ab9bfdbcc8c 100644
--- a/colossalai/shardformer/policies/basepolicy.py
+++ b/colossalai/shardformer/policies/basepolicy.py
@@ -48,6 +48,16 @@ class Row_Layer(Layer):
 class Policy():
     """
     The base class for all the policies
+    Each model should have its own policy class, e.g. BertPolicy for the BERT model
+    or OPTPolicy for the OPT model.
+    AutoPolicy:
+        Shardformer already defines policies for some huggingface models; just set custom_policy = None
+        to use the auto policy. In the shardformer auto policy, we define one base policy per model type,
+        like BertPolicy, and for each derived Bert model in huggingface, such as BertForMaskedLM or
+        BertForSequenceClassification, we define a policy subclass that overwrites the inject_policy method.
+
+    CustomPolicy:
+        Pass your own Policy subclass as custom_policy to shard the model in a custom way.
     """
     @staticmethod
     def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]:
@@ -89,7 +99,7 @@ def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]
         }

         """
-        return {}
+        raise NotImplementedError


     @staticmethod
@@ -102,6 +112,7 @@ def inject_policy() -> Tuple[nn.Module, nn.Module]:
         """
         return ()

+
     @staticmethod
     def attn_in() -> List:
         """
@@ -110,7 +121,8 @@ def attn_in() -> List:
         Returns:
             List[Layer]: List of layer object, each layer is the new
         """
-        return []
+        return NotImplementedError
+

     @staticmethod
     def attn_out() -> List:
         """
@@ -120,7 +132,8 @@ def attn_out() -> List:
         Returns:
             List[Layer]: List of layer object
         """
-        return []
+        return NotImplementedError
+

     @staticmethod
     def mlp_in() -> List:
         """
@@ -130,7 +143,8 @@ def mlp_in() -> List:
         Returns:
             List[Layer]: List of layer object
         """
-        return []
+        return NotImplementedError
+

     @staticmethod
     def mlp_out() -> List:
         """
@@ -140,7 +154,8 @@ def mlp_out() -> List:
         Returns:
             List[Layer]: List of layer object
         """
-        return []
+        return NotImplementedError
+

     @staticmethod
     def embedding()->List:
         """
@@ -151,7 +166,8 @@ def embedding()->List:
         Return:
             List[Layer]: List of layer object
         """
-        return []
+        return NotImplementedError
+

     @staticmethod
     def unembedding()->List:
         """
@@ -162,17 +178,4 @@ def unembedding()->List:
         Return:
             List[Layer]: List of layer object
         """
-        return []
-
-
-    # @staticmethod
-    # def original_layer_class() -> Type[nn.Module]:
-    #     """
-    #     Class to apply the policy to
-    #     e.g. BertLayer, GPT2Block, BartEncoderLayer, ...
-
-    #     Returns:
-    #         Type[nn.Module]: original layer class
-    #     """
-    #     raise NotImplementedError
-
+        return NotImplementedError
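For reference, a custom policy under this API could look like the hypothetical sketch below. The OPT layer class and attribute paths come from the huggingface OPT model and are illustrative only; nothing in this series exercises them:

``` python
import colossalai.nn as col_nn
from transformers.models.opt.modeling_opt import OPTDecoderLayer

from colossalai.shardformer.policies.basepolicy import Argument, Col_Layer, Policy, Row_Layer


class OPTPolicy(Policy):
    """Hypothetical policy for an OPT-style decoder layer."""

    @staticmethod
    def argument_policy(config, world_size: int):
        return {
            OPTDecoderLayer: Argument(
                # each rank keeps a slice of the attention heads
                attr_dict={"self_attn.num_heads": config.num_attention_heads // world_size},
                param_funcs=[OPTPolicy.attn_in, OPTPolicy.attn_out],
            )
        }

    @staticmethod
    def attn_in():
        # q/k/v projections are column-sharded
        return [
            Col_Layer(weight="self_attn.q_proj.weight", bias="self_attn.q_proj.bias",
                      replace_layer=col_nn.Linear),
            Col_Layer(weight="self_attn.k_proj.weight", bias="self_attn.k_proj.bias",
                      replace_layer=col_nn.Linear),
            Col_Layer(weight="self_attn.v_proj.weight", bias="self_attn.v_proj.bias",
                      replace_layer=col_nn.Linear),
        ]

    @staticmethod
    def attn_out():
        # the output projection is row-sharded
        return [
            Row_Layer(weight="self_attn.out_proj.weight", bias="self_attn.out_proj.bias",
                      replace_layer=col_nn.Linear),
        ]
```

diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 23aaf3b1c201..83a4ced946ec 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -33,14 +33,13 @@ def __init__(

     def shard(self) -> None:
-        self.inject_model(self.model, self.policy)
-        self.replace_layer(self.model, self.policy)
+        self.inject_model(self.model)
+        self.replace_layer(self.model)


    def inject_model(
         self,
         model: nn.Module,
-        policy_cls: Policy
     ) -> None:
         """
         Replace the model to policy defined model
         Mainly modify the forward and backward to fit distributed model

         e.g. 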
BertForMaskedLM.forward -> BertForMaskedLM_.forward """ - inject_methods = ["forward"] - inject_policy = policy_cls.inject_policy() + inject_policy = self.policy.inject_policy() org_model_cls = inject_policy[0] shard_model_cls = inject_policy[1] if model.__class__ == org_model_cls: - for inject_method in inject_methods: - if hasattr(model, inject_method): + for key in shard_model_cls.__dict__.keys(): + if hasattr(model.__class__, key): setattr( - model, - inject_method, - getattr(shard_model_cls,inject_method), + model.__class__, + key, + getattr(shard_model_cls,key), ) else: raise NotImplementedError(f"{model.__class__} is not implemented so far") @@ -70,7 +68,6 @@ def inject_model( def replace_layer( self, model: nn.Module, - policy_cls: Policy ) -> None: """ Replace the layer according to the policy, and replace the layer one by one @@ -78,7 +75,7 @@ def replace_layer( Args: layer: The layer to shard """ - argument_policies = policy_cls.argument_policy(self.model_config, 2) + argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size) for argument_policy in argument_policies.items(): origin_layer_cls = argument_policy[0] attr_dict = argument_policy[1].attr_dict @@ -116,7 +113,11 @@ def reverse_replace_layer( return layer - def shard_one_layer(self, org_layer: nn.Module, param_funcs: List[Callable]) -> None: + def shard_one_layer( + self, + org_layer: nn.Module, + param_funcs: List[Callable] + ) -> None: """ Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict @@ -156,13 +157,14 @@ def shard_one_layer(self, org_layer: nn.Module, param_funcs: List[Callable]) -> assert weight is not None or bias is not None layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr) + print(weight.shape) weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__) # create new object to replace the origin layer # TODO: col_nn if replace_layer_cls is not None: - replece_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=True) - # print(replece_layer) + print(weight.shape) + replece_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=True) # replece_layer.weight = nn.Parameter(weight) # replece_layer.bias = nn.Parameter(bias) setattr_(org_layer, layer_attr, replece_layer, ignore=ignore) @@ -171,7 +173,13 @@ def shard_one_layer(self, org_layer: nn.Module, param_funcs: List[Callable]) -> self.set_param(org_layer, layer_attr, weight, bias) - def set_param(self, layer: Any, layer_attr: str, weight: torch.Tensor, bias: torch.Tensor = None) -> None: + def set_param( + self, + layer: Any, + layer_attr: str, + weight: torch.Tensor, + bias: torch.Tensor = None + ) -> None: """ Reset the weight and bias of the layer object @@ -198,6 +206,7 @@ def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> layer_attr: The attribute name of the layer size: Torch.size """ + # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features attrs = ["out_features", "in_features"] for i, attr in enumerate(attrs): if hasattr_(layer, f"{layer_attr}.{attr}"): diff --git a/colossalai/shardformer/shard/shardmodel.py b/colossalai/shardformer/shard/shardmodel.py index ac2878f80171..54d7b5ba02d9 100644 --- a/colossalai/shardformer/shard/shardmodel.py +++ b/colossalai/shardformer/shard/shardmodel.py @@ -12,7 +12,7 @@ from .shardconfig import ShardConfig -class ShardModel(): +class ShardModel(object): """ The class for sharding the huggingface model, 
self.model is the sharded model
     Just create a new ShardModel object to shard a huggingface model

diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py
index 353891fe37bd..107466364e6d 100644
--- a/colossalai/shardformer/shard/slicer.py
+++ b/colossalai/shardformer/shard/slicer.py
@@ -103,7 +103,7 @@ def slice_col(
             tensor: The tensor to slice
         """
         delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
-        down_idx = (self.shardconfig.rank - 1) * delta
+        down_idx = self.shardconfig.rank * delta
         up_idx = down_idx + delta
         return tensor[down_idx:up_idx,:]
@@ -119,6 +119,6 @@ def slice_row(
             tensor: The tensor to slice
         """
         delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
-        down_idx = (self.shardconfig.rank - 1) * delta
+        down_idx = self.shardconfig.rank * delta
         up_idx = down_idx + delta
         return tensor[:,down_idx:up_idx]
\ No newline at end of file

From 280e2e06db75bb6362641ad87ab59f17302ca086 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Mon, 22 May 2023 14:35:51 +0800
Subject: [PATCH 07/12] fix some bugs; add an inference test example

---
 colossalai/shardformer/policies/basepolicy.py |  5 +-
 colossalai/shardformer/policies/bert.py       | 61 ++++++++++++-------
 colossalai/shardformer/shard/sharder.py       | 57 ++++++++++++-----
 colossalai/shardformer/shard/slicer.py        | 57 ++++++++++++++---
 colossalai/shardformer/test/config.py         |  5 ++
 colossalai/shardformer/test/test.py           | 37 +++++++++++
 colossalai/shardformer/utils/utils.py         |  5 +-
 7 files changed, 179 insertions(+), 48 deletions(-)
 create mode 100644 colossalai/shardformer/test/config.py
 create mode 100644 colossalai/shardformer/test/test.py

diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py
index 4ab9bfdbcc8c..d444aeb53bf8 100644
--- a/colossalai/shardformer/policies/basepolicy.py
+++ b/colossalai/shardformer/policies/basepolicy.py
@@ -5,12 +5,13 @@
 import colossalai.nn as col_nn
 from typing import Any, Dict, List, Type, Tuple, Callable
 from transformers import AutoConfig
-from dataclasses import dataclass
+from dataclasses import dataclass, field

 @dataclass
 class Argument:
     attr_dict : Dict[str, Any]
     param_funcs : List[Callable]
+    binding_layers : List[nn.Module] = field(default_factory=list)

 @dataclass
 class Layer:
@@ -34,7 +35,7 @@ class Col_Layer(Layer):
     """
     Class for col shard layer in MegatronLM
     """
-    pass
+    gather_output: bool = False


 @dataclass

diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index d84a735eb0fd..24b95e827347 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -2,7 +2,7 @@
 import torch.nn as nn
 from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer
 import colossalai.nn as col_nn
-from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings
+from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead
 from dataclasses import dataclass


@@ -36,6 +36,18 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]:
             },
             param_funcs = [
                 BertPolicy.embedding,
+            ],
+            binding_layers = [
+                BertLMPredictionHead,
+            ]
         ),
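The new binding_layers field prepares for weight sharing between the sliced word embedding and the LM head. The mechanism it relies on is ordinary PyTorch parameter binding, sketched standalone below with made-up sizes (the two modules stand in for bert.embeddings.word_embeddings and cls.predictions.decoder):

``` python
import torch.nn as nn

embed = nn.Embedding(100, 16)              # stand-in for the word embedding
decoder = nn.Linear(16, 100, bias=False)   # stand-in for the LM head decoder

decoder.weight = embed.weight    # bind: one Parameter shared by two modules

assert decoder.weight.data_ptr() == embed.weight.data_ptr()
embed.weight.data.fill_(0.5)                  # a change through one module...
assert decoder.weight[0, 0].item() == 0.5     # ...is visible through the other
```

+        BertLMPredictionHead: Argument(
+            attr_dict = {
+                # 1. shard vocab size
+                # "word_embeddings.num_embeddings": config.vocab_size // world_size,
+                # 2. 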
add the size of the sliced embedding layer excluding the last slice + }, + param_funcs = [ BertPolicy.unembedding, ] ) @@ -47,34 +59,34 @@ def attn_in() -> List: Col_Layer( weight="attention.self.query.weight", bias="attention.self.query.bias", - replace_layer=col_nn.Linear, + replace_layer=col_nn.Linear1D_Col, ), - Layer( + Col_Layer( weight="attention.self.key.weight", bias="attention.self.key.bias", - replace_layer=col_nn.Linear, + replace_layer=col_nn.Linear1D_Col, ), - Layer( + Col_Layer( weight="attention.self.value.weight", bias="attention.self.value.bias", - replace_layer=col_nn.Linear, + replace_layer=col_nn.Linear1D_Col, ), - Layer( + Col_Layer( weight="crossattention.self.query.weight", bias="crossattention.self.query.bias", - replace_layer=col_nn.Linear, + replace_layer=col_nn.Linear1D_Col, ignore=True, ), - Layer( + Col_Layer( weight="crossattention.self.key.weight", bias="crossattention.self.key.bias", - replace_layer=col_nn.Linear, + replace_layer=col_nn.Linear1D_Col, ignore=True, ), - Layer( + Col_Layer( weight="crossattention.self.value.weight", bias="crossattention.self.value.bias", - replace_layer=col_nn.Linear, + replace_layer=col_nn.Linear1D_Col, ignore=True, ), @@ -83,15 +95,15 @@ def attn_in() -> List: @staticmethod def attn_out() -> List: return [ - Layer( + Row_Layer( weight="attention.output.dense.weight", bias="attention.output.dense.bias", - # replace_layer=col_nn.Linear, + replace_layer=col_nn.Linear1D_Row, ), - Layer( + Row_Layer( weight="crossattention.output.dense.weight", bias="crossattention.output.dense.bias", - # replace_layer=col_nn.Linear, + replace_layer=col_nn.Linear1D_Row, ignore=True, ), ] @@ -102,33 +114,38 @@ def mlp_in() -> List: Col_Layer( weight="intermediate.dense.weight", bias="intermediate.dense.bias", - # replace_layer=col_nn.Linear, + replace_layer=col_nn.Linear1D_Col, ), ] @staticmethod def mlp_out() -> List: return [ - Row_Layer( + Row_Layer( weight="output.dense.weight", bias="output.dense.bias", - # replace_layer=col_nn.Linear, + replace_layer=col_nn.Linear1D_Row, ), ] @staticmethod def embedding() -> List: return [ - Layer( + Col_Layer( weight="word_embeddings.weight", - # replace=AllReduceEmbedding, + replace_layer=col_nn.VocabParallelEmbedding1D, ) ] @staticmethod def unembedding() -> List: return [ - + Col_Layer( + weight="decoder.weight", + bias="decoder.bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ) ] from transformers import BertForMaskedLM diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 83a4ced946ec..ef785cfee9da 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -8,8 +8,12 @@ from .slicer import Slicer from ..utils.utils import hasattr_, setattr_, getattr_ import colossalai.nn as col_nn +from colossalai.logging import get_dist_logger +import os +logger = get_dist_logger() + class ModelSharder(object): """ Shard the original huggingface model according to the policy @@ -30,6 +34,7 @@ def __init__( self.slicer = Slicer(shard_config) self.shard_config = shard_config self.model_config = self.model.config + self.binding_map = {} def shard(self) -> None: @@ -80,7 +85,10 @@ def replace_layer( origin_layer_cls = argument_policy[0] attr_dict = argument_policy[1].attr_dict param_funcs = argument_policy[1].param_funcs - self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) + binding_layers = argument_policy[1].binding_layers + # if binding_layer is not None: + # 
self.binding_map[origin_layer_cls] = binding_layer + self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers) def reverse_replace_layer( @@ -89,6 +97,7 @@ def reverse_replace_layer( origin_cls: nn.Module, attr_dict: Dict[str, Any], param_funcs: List[Callable], + binding_layers: List[nn.Module] ) -> None: """ Reverse the replace layer operation @@ -106,17 +115,18 @@ def reverse_replace_layer( setattr_(child, k, v, ignore=True) # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__) # setattr_(layer, name, self.shard_one_layer(child, policy_cls)) - self.shard_one_layer(child, param_funcs) + self.shard_one_layer(child, param_funcs, binding_layers) continue - self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs) + self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers) return layer def shard_one_layer( self, org_layer: nn.Module, - param_funcs: List[Callable] + param_funcs: List[Callable], + binding_layers: List[nn.Module] ) -> None: """ Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict @@ -136,6 +146,9 @@ def shard_one_layer( bias_attr = policy_layer.bias replace_layer_cls = policy_layer.replace_layer ignore = policy_layer.ignore + if policy_layer.__class__.__name__ == "Col_Layer": + gather_output = policy_layer.gather_output + print(gather_output) if weight_attr is not None: if hasattr_(org_layer, weight_attr): @@ -157,17 +170,29 @@ def shard_one_layer( assert weight is not None or bias is not None layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr) - print(weight.shape) + # slice weight and bias weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__) - + print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None) + # save the binding information + for binding_layer in binding_layers: + self.binding_map[binding_layer] = dict(weight=weight, bias=bias) + # create new object to replace the origin layer - # TODO: col_nn if replace_layer_cls is not None: - print(weight.shape) - replece_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=True) - # replece_layer.weight = nn.Parameter(weight) - # replece_layer.bias = nn.Parameter(bias) - setattr_(org_layer, layer_attr, replece_layer, ignore=ignore) + # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}") + if isinstance(getattr_(org_layer, layer_attr), nn.Linear): + if replace_layer_cls.__name__ == "Linear1D_Row": + replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=False if bias is None else True) + elif replace_layer_cls.__name__ == "Linear1D_Col": + replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=False if bias is None else True, gather_output=gather_output) + setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) + self.set_param(replace_layer, weight, bias) + elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): + replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) + setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) + self.set_param(replace_layer, weight, bias) + else: + raise NotImplementedError(f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") # do not replace the layer object, just replace the weight and 
bias else: self.set_param(org_layer, layer_attr, weight, bias) @@ -176,8 +201,8 @@ def shard_one_layer( def set_param( self, layer: Any, - layer_attr: str, - weight: torch.Tensor, + layer_attr: str = "", + weight: torch.Tensor = None, bias: torch.Tensor = None ) -> None: """ @@ -191,10 +216,10 @@ def set_param( """ assert weight is not None or bias is not None if weight is not None: - setattr_(layer, layer_attr+".weight", nn.Parameter(weight)) + setattr_(layer, "weight" if layer_attr == "" else layer_attr+".weight", nn.Parameter(weight)) self.set_layer_size(layer, layer_attr, weight.shape) if bias is not None: - setattr_(layer, layer_attr+".bias", nn.Parameter(bias)) + setattr_(layer, "bias" if layer_attr == "" else layer_attr+".bias", nn.Parameter(bias)) def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None: diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 107466364e6d..1849cdc99c72 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -17,16 +17,19 @@ def __init__( shardconfig: ShardConfig #TODO ) -> None: self.shardconfig = shardconfig - + def slice_weight_bias( self, weight: torch.Tensor, bias: torch.Tensor, policy_layer_cls: Layer, - ) -> Tuple[torch.Tensor,torch.Tensor]: + ): """ - Slice the weight and bias according to the shardconfig + Slice the weight and bias according to policy layer cls + Layer -> do nothing + Col_Layer -> slice the weight and bias along dim 1 + Row_Layer -> slice the weight along dim 0 and do not slice bias Args: weight: The weight of the layer @@ -35,13 +38,49 @@ def slice_weight_bias( """ if policy_layer_cls == Layer: return weight, bias + elif policy_layer_cls == Col_Layer: + weight = self.slice_tensor(weight, 1, False) + bias = self.slice_tensor(bias, 0, True) + elif policy_layer_cls == Row_Layer: + weight = self.slice_tensor(weight, 0, False) + else: + raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported") + return weight, bias + + + def slice_weight( + self, + weight: torch.Tensor, + policy_layer_cls: Layer, + ) -> torch.Tensor: + """ + Slice the weight and bias according to the shardconfig + + Args: + weight: The weight of the layer + bias: The bias of the layer + policy_layer_class: The class represent how to slice the tensor + """ if weight is not None: - assert policy_layer_cls in dim_mapping, f"The policy layer class {policy_layer_cls} is not supported" dim = dim_mapping[policy_layer_cls] weight = self.slice_tensor(weight, dim, False) + return weight + + + def slice_bias( + self, + bias: torch.Tensor, + ) -> torch.Tensor: + """ + Slice the bias according to the shardconfig + + Args: + bias: The bias of the layer + """ + assert bias is not None, "The bias is None" if bias is not None: bias = self.slice_tensor(bias, 1, True) - return weight, bias + return bias def slice_tensor( @@ -53,6 +92,8 @@ def slice_tensor( """ Slice tensor according to the config """ + if tensor_in is None: + return None if not is_bias: return self.slice_2d(tensor_in, dim) else: @@ -76,6 +117,7 @@ def slice_2d( elif dim == 1: return self.slice_col(tensor) + def slice_1d( self, tensor: torch.Tensor, @@ -88,7 +130,7 @@ def slice_1d( tensor: The tensor to slice """ delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size - down_idx = (self.shardconfig.rank - 1) * delta + down_idx = self.shardconfig.rank * delta up_idx = down_idx + delta return tensor[down_idx:up_idx] @@ -121,4 +163,5 @@ 
def slice_row( delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size down_idx = self.shardconfig.rank * delta up_idx = down_idx + delta - return tensor[:,down_idx:up_idx] \ No newline at end of file + return tensor[:,down_idx:up_idx] + \ No newline at end of file diff --git a/colossalai/shardformer/test/config.py b/colossalai/shardformer/test/config.py new file mode 100644 index 000000000000..295529429237 --- /dev/null +++ b/colossalai/shardformer/test/config.py @@ -0,0 +1,5 @@ +parallel = dict( + data=1, + pipeline=1, + tensor=dict(size=2, mode='1d') +) \ No newline at end of file diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py new file mode 100644 index 000000000000..c2a9053ca2f6 --- /dev/null +++ b/colossalai/shardformer/test/test.py @@ -0,0 +1,37 @@ +from transformers import AutoTokenizer +from transformers import BertForMaskedLM +import colossalai +from colossalai.shardformer.shard.shardmodel import ShardModel +from colossalai.utils import get_current_device, print_rank_0 +from colossalai.logging import get_dist_logger +from colossalai.shardformer.shard.shardconfig import ShardConfig +import inspect +import argparse +import torch.nn as nn +import os + +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + +def get_args(): + parser = colossalai.get_default_parser() + return parser.parse_args() + +def inference(model: nn.Module): + # print(model) + token = "Hello, my dog is cute" + inputs = tokenizer(token, return_tensors="pt") + inputs.to("cuda") + model.to("cuda") + outputs = model(**inputs) + print(outputs) + +if __name__ == "__main__": + args = get_args() + colossalai.launch_from_torch(config=args.config) + model = BertForMaskedLM.from_pretrained("bert-base-uncased") + shard_config = ShardConfig( + rank = int(str(get_current_device()).split(':')[-1]), + world_size= int(os.environ['WORLD_SIZE']), + ) + shardmodel = ShardModel(model, shard_config) + inference(shardmodel.model) diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py index 87fe714899f8..5eba87f6fe09 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/utils/utils.py @@ -35,13 +35,14 @@ def setattr_(obj, attr: str, value, ignore: bool=False): raise AttributeError(f"Object {obj} has no attribute {attr}") setattr(obj, attrs[-1], value) -def getattr_(obj, attr: str): +def getattr_(obj, attr: str, ignore: bool=None): """ Get the object's multi sublevel attr Args: obj: The object to set attr: The multi level attr to set + ignore: Whether to ignore when the attr doesn't exist """ attrs = attr.split('.') @@ -49,5 +50,7 @@ def getattr_(obj, attr: str): try: obj = getattr(obj, a) except AttributeError: + if ignore: + return None raise AttributeError(f"Object {obj} has no attribute {attr}") return obj \ No newline at end of file From e6e8f182ac02e7c49c7d1970b79288ab6bb2759f Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 22 May 2023 17:58:34 +0800 Subject: [PATCH 08/12] add share weight and train example --- colossalai/shardformer/policies/basepolicy.py | 8 ++- colossalai/shardformer/policies/bert.py | 9 ++- colossalai/shardformer/shard/sharder.py | 42 +++++++------ colossalai/shardformer/test/test.py | 62 ++++++++++++++++++- 4 files changed, 96 insertions(+), 25 deletions(-) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index d444aeb53bf8..aa27495b2600 100644 --- 
a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -11,7 +11,6 @@ class Argument: attr_dict : Dict[str, Any] param_funcs : List[Callable] - binding_layers : List[nn.Module] = field(default_factory=list) @dataclass class Layer: @@ -114,6 +113,13 @@ def inject_policy() -> Tuple[nn.Module, nn.Module]: return () + @staticmethod + def binding_policy() -> Dict: + """ + Return the dict for the binding model + """ + return NotImplementedError + @staticmethod def attn_in() -> List: """ diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 24b95e827347..a8532100c0da 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -36,9 +36,6 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]: }, param_funcs = [ BertPolicy.embedding, - ], - binding_layers = [ - BertLMPredictionHead, ] ), BertLMPredictionHead: Argument( @@ -53,6 +50,12 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]: ) } + @staticmethod + def binding_policy() -> Dict: + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", + } + @staticmethod def attn_in() -> List: return [ diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index ef785cfee9da..1c2302faae7b 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -34,13 +34,13 @@ def __init__( self.slicer = Slicer(shard_config) self.shard_config = shard_config self.model_config = self.model.config - self.binding_map = {} def shard(self) -> None: self.inject_model(self.model) self.replace_layer(self.model) - + self.bind_layer(self.model) + def inject_model( self, @@ -85,10 +85,7 @@ def replace_layer( origin_layer_cls = argument_policy[0] attr_dict = argument_policy[1].attr_dict param_funcs = argument_policy[1].param_funcs - binding_layers = argument_policy[1].binding_layers - # if binding_layer is not None: - # self.binding_map[origin_layer_cls] = binding_layer - self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers) + self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) def reverse_replace_layer( @@ -97,7 +94,6 @@ def reverse_replace_layer( origin_cls: nn.Module, attr_dict: Dict[str, Any], param_funcs: List[Callable], - binding_layers: List[nn.Module] ) -> None: """ Reverse the replace layer operation @@ -115,10 +111,10 @@ def reverse_replace_layer( setattr_(child, k, v, ignore=True) # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__) # setattr_(layer, name, self.shard_one_layer(child, policy_cls)) - self.shard_one_layer(child, param_funcs, binding_layers) + self.shard_one_layer(child, param_funcs) continue - self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers) + self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs) return layer @@ -126,7 +122,6 @@ def shard_one_layer( self, org_layer: nn.Module, param_funcs: List[Callable], - binding_layers: List[nn.Module] ) -> None: """ Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict @@ -148,7 +143,7 @@ def shard_one_layer( ignore = policy_layer.ignore if policy_layer.__class__.__name__ == "Col_Layer": gather_output = policy_layer.gather_output - print(gather_output) + # print(gather_output) if weight_attr is not None: if 
hasattr_(org_layer, weight_attr): @@ -172,10 +167,7 @@ def shard_one_layer( # slice weight and bias weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__) - print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None) - # save the binding information - for binding_layer in binding_layers: - self.binding_map[binding_layer] = dict(weight=weight, bias=bias) + # print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None) # create new object to replace the origin layer if replace_layer_cls is not None: @@ -201,9 +193,9 @@ def shard_one_layer( def set_param( self, layer: Any, - layer_attr: str = "", weight: torch.Tensor = None, - bias: torch.Tensor = None + bias: torch.Tensor = None, + layer_attr: str = "" ) -> None: """ Reset the weight and bias of the layer object @@ -235,4 +227,18 @@ def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> attrs = ["out_features", "in_features"] for i, attr in enumerate(attrs): if hasattr_(layer, f"{layer_attr}.{attr}"): - setattr_(layer, f"{layer_attr}.{attr}", size[i]) + setattr_(layer, f"{layer_attr}.{attr}", size[i]) + + + def bind_layer( + self, + model: nn.Module + ) -> None: + binding_map = self.policy.binding_policy() + for k,v in binding_map.items(): + param = getattr_(model, k) + param = nn.Parameter(param) + setattr_(model, k, param) + setattr_(model, v, param) + + diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py index c2a9053ca2f6..35bbec202222 100644 --- a/colossalai/shardformer/test/test.py +++ b/colossalai/shardformer/test/test.py @@ -1,4 +1,4 @@ -from transformers import AutoTokenizer +from transformers import AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling from transformers import BertForMaskedLM import colossalai from colossalai.shardformer.shard.shardmodel import ShardModel @@ -7,17 +7,39 @@ from colossalai.shardformer.shard.shardconfig import ShardConfig import inspect import argparse +import torch +from tqdm.auto import tqdm import torch.nn as nn +from torch.utils.data import DataLoader +from datasets import load_dataset import os tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") def get_args(): parser = colossalai.get_default_parser() + parser.add_argument("--mode", type=str, default='inference') return parser.parse_args() +def load_data(): + datasets=load_dataset('wikitext', 'wikitext-2-raw-v1') + # datasets=load_dataset("yelp_review_full") + tokenized_datasets=datasets.map(lambda examples:tokenizer(examples["text"],truncation=True,padding="max_length"),batched=True) + tokenized_datasets=tokenized_datasets.remove_columns(["text"]) + # tokenized_datasets=tokenized_datasets.rename_column("label","labels") + tokenized_datasets.set_format("torch") + + train_dataset=tokenized_datasets["train"].select(range(1000)) + test_dataset=tokenized_datasets["test"].select(range(100)) + + datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt") + train_dataloader=DataLoader(train_dataset,batch_size=8,shuffle=True, collate_fn=datacollector) + eval_dataloader=DataLoader(test_dataset,batch_size=8, collate_fn=datacollector) + return train_dataloader,eval_dataloader + def inference(model: nn.Module): - # print(model) + print(model) + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") token = "Hello, my dog is cute" inputs = tokenizer(token, return_tensors="pt") 
inputs.to("cuda") @@ -25,6 +47,37 @@ def inference(model: nn.Module): outputs = model(**inputs) print(outputs) +def train(model: nn.Module, num_epoch: int=2): + train_dataloader, eval_dataloader=load_data() + optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) + progress_bar = tqdm(range((num_epoch)*len(train_dataloader))) + criterion = nn.CrossEntropyLoss() + model.to("cuda") + model.train() + for epoch in range(num_epoch): + for batch in train_dataloader: + optimizer.zero_grad() + batch = {k: v.to('cuda') for k, v in batch.items()} + outputs = model(**batch) + loss = outputs.loss + print(loss) + # loss = criterion(outputs.logits, batch["labels"]) + loss.backward() + optimizer.step() + progress_bar.update(1) + progress_bar.set_description(f"loss: {loss.item()}") + print(f"Rank {os.environ['RANK']} Epoch:{epoch} Train Loss:{loss:.4f}") + + for batch in eval_dataloader: + batch = {k: v.to('cuda') for k, v in batch.items()} + outputs = model(**batch) + loss = outputs['loss'] + # loss = criterion(outputs.logits, batch["input_ids"]) + print(f"Rank {os.environ['RANK']} Epoch:{epoch} Test Loss:{loss:.4f}") + + + + if __name__ == "__main__": args = get_args() colossalai.launch_from_torch(config=args.config) @@ -34,4 +87,7 @@ def inference(model: nn.Module): world_size= int(os.environ['WORLD_SIZE']), ) shardmodel = ShardModel(model, shard_config) - inference(shardmodel.model) + if args.mode == "train": + train(shardmodel.model) + elif args.mode == "inference": + inference(shardmodel.model) From 08595e64212236a11a60f7b44a13c569ff6e8557 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Tue, 23 May 2023 11:24:55 +0800 Subject: [PATCH 09/12] add train --- colossalai/shardformer/model/modeling_bert.py | 2 +- colossalai/shardformer/test/test.py | 21 ++++++++----------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py index 87ed8ac308a5..724be56dca0e 100644 --- a/colossalai/shardformer/model/modeling_bert.py +++ b/colossalai/shardformer/model/modeling_bert.py @@ -23,7 +23,7 @@ def forward( return_dict=None, **kwargs, ): - print("[Inject OK] Injected forward method") + # print("[Inject OK] Injected forward method") return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.bert( diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py index 35bbec202222..70593c3b0176 100644 --- a/colossalai/shardformer/test/test.py +++ b/colossalai/shardformer/test/test.py @@ -55,29 +55,26 @@ def train(model: nn.Module, num_epoch: int=2): model.to("cuda") model.train() for epoch in range(num_epoch): + progress_bar.set_description(f"epoch: {epoch}") + for batch in train_dataloader: optimizer.zero_grad() batch = {k: v.to('cuda') for k, v in batch.items()} outputs = model(**batch) loss = outputs.loss - print(loss) - # loss = criterion(outputs.logits, batch["labels"]) loss.backward() optimizer.step() progress_bar.update(1) - progress_bar.set_description(f"loss: {loss.item()}") - print(f"Rank {os.environ['RANK']} Epoch:{epoch} Train Loss:{loss:.4f}") + print(f"\nRank:{os.environ['RANK']} Epoch:{epoch} Train Loss:{loss:.4f}") - for batch in eval_dataloader: - batch = {k: v.to('cuda') for k, v in batch.items()} - outputs = model(**batch) - loss = outputs['loss'] - # loss = criterion(outputs.logits, batch["input_ids"]) - print(f"Rank {os.environ['RANK']} Epoch:{epoch} Test Loss:{loss:.4f}") + # for batch in eval_dataloader: + # 
batch = {k: v.to('cuda') for k, v in batch.items()} + # outputs = model(**batch) + # loss = outputs['loss'] + # # loss = criterion(outputs.logits, batch["input_ids"]) + # print(f"\nRank {os.environ['RANK']} Epoch:{epoch} Test Loss:{loss:.4f}") - - if __name__ == "__main__": args = get_args() colossalai.launch_from_torch(config=args.config) From 011a007062037043dc0bd50636adae2f154eab55 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Tue, 23 May 2023 16:04:31 +0800 Subject: [PATCH 10/12] add docstring and readme --- colossalai/nn/layer/parallel_1d/_operation.py | 2 + colossalai/nn/layer/parallel_1d/layers.py | 9 +- colossalai/shardformer/README.md | 177 ++++++++++++++++++ colossalai/shardformer/shard/sharder.py | 153 +++++++-------- colossalai/shardformer/shard/shardmodel.py | 36 ++-- colossalai/shardformer/test/test.py | 82 ++++---- 6 files changed, 328 insertions(+), 131 deletions(-) create mode 100644 colossalai/shardformer/README.md diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py index 394334558275..c5e33fd497cd 100644 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist + from colossalai.core import global_context as gpc try: @@ -72,6 +73,7 @@ def backward(ctx, grad_output): total_input = input grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() # Convert the tensor shapes to 2D for execution compatibility grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 406173a18c60..0ee3b4fcb502 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -469,7 +469,8 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) + # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size) + self.out_features_per_partition = out_features # Parameters. # Initialize weight. @@ -612,7 +613,8 @@ def __init__(self, raise ValueError('cannot skip bias addition if bias is None') # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size) + self.input_size_per_partition = in_features # Parameters. # Initialize weight. 
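Both hunks above (and the embedding hunk below) make the same change: the sharder now hands the layer an already-sliced weight, so the constructor argument is treated as a per-partition size instead of being divided by the tensor-parallel degree again. A small standalone sketch of the difference, with made-up shapes:

``` python
import torch

world_size = 2                     # assumed tensor-parallel degree
full = torch.randn(3072, 768)      # made-up h -> 4h weight

# Old behaviour: the layer received the full size and divided it internally.
old_out_per_partition = full.shape[0] // world_size     # 1536

# New behaviour: slicing happens upstream in the Slicer, so the constructor
# argument is already the local size and must not be divided a second time.
local = full[:full.shape[0] // world_size]
new_out_per_partition = local.shape[0]                  # already 1536

assert old_out_per_partition == new_out_per_partition
```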
@@ -884,7 +886,8 @@ def __init__(self,
         tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
         tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

-        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
+        # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
+        self.num_embeddings_per_partition = num_embeddings
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
new file mode 100644
index 000000000000..a47e280f2be4
--- /dev/null
+++ b/colossalai/shardformer/README.md
@@ -0,0 +1,177 @@
+## ShardFormer
+
+### Intro
+Make models from huggingface.co parallelizable, so they can be used with colossalai according to a custom policy.
+
+### Quick start
+1. Usage
+- Use
+``` python
+from colossalai.shardformer.shard.shardmodel import ShardModel
+from transformers import BertForMaskedLM
+
+# create huggingface model as normal
+model = BertForMaskedLM.from_pretrained("bert-base-uncased")
+
+# make the huggingface model paralleled as a ShardModel
+# auto policy:
+shardmodel = ShardModel(model).model
+
+# custom policy:
+from xxx import 
+shardmodel = ShardModel(model, ).model
+
+
+# do anything as normal
+...
+```
+- Policy
+
+If you want to parallelize the model in a custom way, just overwrite the policy class for the huggingface model.
+
+You should do:
+
+1. Inherit the Policy class
+2. Overwrite the argument_policy method
+    - In this method you need to list which layer classes you want to modify, and the attributes and parameters of those layers.
+3. Overwrite the inject_policy method [Optional]
+    - Do this if you need to modify the forward or backward pass.
+4. Overwrite or add the param recording functions
+    - These functions use a suffix to record the path of the weight or bias inside the layer.
+5. Overwrite the binding_policy method
+
+More details can be found in shardformer/policies/basepolicy.py
+``` python
+from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument
+
+class CustomPolicy(Policy):
+    @staticmethod
+    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]:
+        """
+        Return a dict; the key is the layer that will be modified and the value is an Argument class
+        with the param settings and param functions
+
+        Args:
+            model_config: The config of transformer model
+            shard_setting: The config of distributed model
+
+        Return:
+            Dict for the modify policy,
+            {
+                origin layer class1 (nn.Module): Argument(
+                    attr_dict = {
+                        argument1: value1,
+                        argument2: value2,
+                        ...
+                    },
+                    param_funcs = [
+                        staticmethod1,
+                        staticmethod2,
+                        ...
+                    ]
+                ),
+                origin layer class2 (nn.Module): Argument(
+                    attr_dict = {
+                        argument1: value1,
+                        argument2: value2,
+                        ...
+                    },
+                    param_funcs = [
+                        staticmethod1,
+                        staticmethod2,
+                        ...
+                    ]
+                ),
+                ...
+            
+ } + + """ + raise NotImplementedError + + @staticmethod + def inject_policy() -> Tuple[nn.Module, nn.Module]: + """ + Return the dict for the inject model + + Return: + The injected model, key is the original model and value is the new shardmodel + """ + return () + + @staticmethod + def binding_policy() -> Dict: + """ + Return the dict for the binding model + """ + return NotImplementedError + + @staticmethod + def attn_in() -> List: + """ + Attention qkv layer + + Returns: + List[Layer]: List of layer object, each layer is the new + """ + return NotImplementedError + + @staticmethod + def attn_out() -> List: + """ + Attention output projection layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def mlp_in() -> List: + """ + h -> 4h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def mlp_out() -> List: + """ + 4h -> h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def embedding() -> List: + """ + Partially slice the embedding layer + vocab_size->vocab_size//gpu_nums + + Return: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def unembedding() -> List: + """ + Partially slice the embedding layer + vocab_size->vocab_size//gpu_nums + + Return: + List[Layer]: List of layer object + """ + return NotImplementedError + +``` + +2. Simple example +``` shell +# inference +colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference +# train +colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train +``` diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 1c2302faae7b..22145b7488d0 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,56 +1,59 @@ +import os +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union + import torch import torch.nn as nn -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union, Callable -from .shardconfig import ShardConfig -from dataclasses import dataclass -from ..policies.basepolicy import Policy, Layer -from ..policies.autopolicy import get_autopolicy -from .slicer import Slicer -from ..utils.utils import hasattr_, setattr_, getattr_ + import colossalai.nn as col_nn from colossalai.logging import get_dist_logger -import os +from ..policies.autopolicy import get_autopolicy +from ..policies.basepolicy import Layer, Policy +from ..utils.utils import getattr_, hasattr_, setattr_ +from .shardconfig import ShardConfig +from .slicer import Slicer logger = get_dist_logger() + class ModelSharder(object): - """ + r""" Shard the original huggingface model according to the policy Args: - policy: The policy to shard the model - model: The model to shard - dist_setting: The setting of distributed model + policy (:class:`Policy`): The policy to shard the model + model (:class:`torch.Module`): The model to shard + shard_config: The setting of distributed model """ + def __init__( self, model: nn.Module, policy: Policy, - shard_config: ShardConfig = None, # TODO - ) -> None: + shard_config: ShardConfig = None, # TODO + ) -> None: self.model = model self.policy = get_autopolicy(self.model) if policy is None else policy self.slicer = Slicer(shard_config) self.shard_config = shard_config 
self.model_config = self.model.config - def shard(self) -> None: self.inject_model(self.model) self.replace_layer(self.model) self.bind_layer(self.model) - def inject_model( - self, - model: nn.Module, - ) -> None: + self, + model: nn.Module, + ) -> None: """ Replace the model to policy defined model Mainly modify the forward and backward to fit distributed model - + e.g. + :: BertForMaskedLM.forward -> BertForMaskedLM_.forward """ inject_policy = self.policy.inject_policy() @@ -64,21 +67,20 @@ def inject_model( setattr( model.__class__, key, - getattr(shard_model_cls,key), + getattr(shard_model_cls, key), ) else: raise NotImplementedError(f"{model.__class__} is not implemented so far") - def replace_layer( - self, - model: nn.Module, - ) -> None: + self, + model: nn.Module, + ) -> None: """ Replace the layer according to the policy, and replace the layer one by one Args: - layer: The layer to shard + model (:class:`torch.nn.Module`): The layer to shard """ argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size) for argument_policy in argument_policies.items(): @@ -87,22 +89,21 @@ def replace_layer( param_funcs = argument_policy[1].param_funcs self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) - def reverse_replace_layer( - self, - layer: nn.Module, - origin_cls: nn.Module, - attr_dict: Dict[str, Any], - param_funcs: List[Callable], - ) -> None: + self, + layer: nn.Module, + origin_cls: nn.Module, + attr_dict: Dict[str, Any], + param_funcs: List[Callable], + ) -> None: """ Reverse the replace layer operation Args: - layer: The object of layer to shard - origin_cls: The origin layer class - attr_dict: The attribute dict to modify - policy_cls: The policy class + layer (:class:`torch.nn.Module`): The object of layer to shard + origin_cls (:class:`transformers.model`): The origin layer class + attr_dict (Dict): The attribute dict to modify + policy_cls (:class:`Policy`): The policy class """ for name, child in layer.named_children(): if child.__class__ == origin_cls: @@ -117,18 +118,17 @@ def reverse_replace_layer( self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs) return layer - def shard_one_layer( - self, - org_layer: nn.Module, - param_funcs: List[Callable], - ) -> None: + self, + org_layer: nn.Module, + param_funcs: List[Callable], + ) -> None: """ Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict Args: - org_layer: The origin layer object to shard - param_funcs: The function list to get shard information in policy class + org_layer (:class:`torch.nn.Module`): The origin layer object to shard + param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class """ # print(org_layer) @@ -174,71 +174,74 @@ def shard_one_layer( # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}") if isinstance(getattr_(org_layer, layer_attr), nn.Linear): if replace_layer_cls.__name__ == "Linear1D_Row": - replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=False if bias is None else True) + replace_layer = replace_layer_cls(weight.shape[1], + weight.shape[0], + bias=False if bias is None else True) elif replace_layer_cls.__name__ == "Linear1D_Col": - replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=False if bias is None else True, gather_output=gather_output) + replace_layer = 
replace_layer_cls(weight.shape[0], + weight.shape[1], + bias=False if bias is None else True, + gather_output=gather_output) setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) self.set_param(replace_layer, weight, bias) - elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): - replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) + elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): + replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], + getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) self.set_param(replace_layer, weight, bias) else: - raise NotImplementedError(f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") + raise NotImplementedError( + f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") # do not replace the layer object, just replace the weight and bias else: self.set_param(org_layer, layer_attr, weight, bias) - - def set_param( - self, - layer: Any, - weight: torch.Tensor = None, - bias: torch.Tensor = None, - layer_attr: str = "" - ) -> None: + def set_param(self, + layer: Any, + weight: torch.Tensor = None, + bias: torch.Tensor = None, + layer_attr: str = "") -> None: """ Reset the weight and bias of the layer object Args: - layer: The layer object - layer_attr: The attribute name of the layer - weight: The weight of the layer - bias: The bias of the layer + layer (:class:`torch.nn.Module`): The layer object + layer_attr (str): The attribute name of the layer + weight (:class:`torch.Tensor`): The weight of the layer + bias (:class:`torch.Tensor`): The bias of the layer """ assert weight is not None or bias is not None if weight is not None: - setattr_(layer, "weight" if layer_attr == "" else layer_attr+".weight", nn.Parameter(weight)) + setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous())) self.set_layer_size(layer, layer_attr, weight.shape) if bias is not None: - setattr_(layer, "bias" if layer_attr == "" else layer_attr+".bias", nn.Parameter(bias)) - + setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous())) def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None: """ Set the layer attribute Args: - layer: The layer object - layer_attr: The attribute name of the layer - size: Torch.size + layer (:class:`torch.nn.Module`): The layer object + layer_attr (str): The attribute name of the layer + size (:class:`torch.Size`): The size of the tensor """ # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features attrs = ["out_features", "in_features"] for i, attr in enumerate(attrs): if hasattr_(layer, f"{layer_attr}.{attr}"): - setattr_(layer, f"{layer_attr}.{attr}", size[i]) + setattr_(layer, f"{layer_attr}.{attr}", size[i]) + def bind_layer(self, model: nn.Module) -> None: + """ + Bind the layer according to the binding policy - def bind_layer( - self, - model: nn.Module - ) -> None: + Args: + model (:class:`torch.nn.Module`): The shard model + """ binding_map = self.policy.binding_policy() - for k,v in binding_map.items(): + for k, v in binding_map.items(): param = getattr_(model, k) param = nn.Parameter(param) setattr_(model, k, param) setattr_(model, v, param) - - diff --git a/colossalai/shardformer/shard/shardmodel.py b/colossalai/shardformer/shard/shardmodel.py index 
54d7b5ba02d9..7e7d1576afd6 100644
--- a/colossalai/shardformer/shard/shardmodel.py
+++ b/colossalai/shardformer/shard/shardmodel.py
@@ -1,46 +1,48 @@
 import os
+from contextlib import suppress
+from dataclasses import dataclass
+
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import transformers
-import torch.distributed as dist
-from dataclasses import dataclass
-from contextlib import suppress

 from colossalai.tensor.d_tensor.layout import Layout
+
 from ..policies.basepolicy import Policy
-from .sharder import ModelSharder
 from .shardconfig import ShardConfig
+from .sharder import ModelSharder


 class ShardModel(object):
-    """
-    The class for sharding the huggingface model, self.model is the sharded model
+    r"""
+    The class for sharding the huggingface model, ``self.model`` is the sharded model
     Just create a new ShardModel object to shard a huggingface model

     Args:
-        model: the origin huggingface model
-        dist_config: the config for distribute information
-        custom_policy: the custom policy for sharding
+        model (:class:`torch.nn.Module`): the original huggingface model
+        dist_config (:class:`ShardConfig`): the config with the distributed information
+        custom_policy (:class:`Policy`): the custom policy for sharding
     """
+
     def __init__(
-        self,
-        model: nn.Module,
-        shard_config: ShardConfig = None,    # TODO
-        custom_policy: Policy = None,
-    ) -> None:
+            self,
+            model: nn.Module,
+            shard_config: ShardConfig = None,    # TODO
+            custom_policy: Policy = None,
+    ) -> None:
         self.model = model
         self.shard_config = shard_config
         self.policy = custom_policy
         # self.layout=,  # TODO

-        sharder=ModelSharder(
+        sharder = ModelSharder(
             model=self.model,
             policy=self.policy,
             shard_config=self.shard_config,
         )
         sharder.shard()

-
     def set_environ(self) -> None:
         os.environ["TOKENIZERS_PARALLELISM"] = "true"
         os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU"
@@ -55,4 +57,4 @@ def set_environ(self) -> None:
         torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

     def back_to_org() -> None:
-        pass
\ No newline at end of file
+        pass

diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py
index 70593c3b0176..0cdc6ef38fd2 100644
--- a/colossalai/shardformer/test/test.py
+++ b/colossalai/shardformer/test/test.py
@@ -1,41 +1,47 @@
-from transformers import AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
-from transformers import BertForMaskedLM
-import colossalai
-from colossalai.shardformer.shard.shardmodel import ShardModel
-from colossalai.utils import get_current_device, print_rank_0
-from colossalai.logging import get_dist_logger
-from colossalai.shardformer.shard.shardconfig import ShardConfig
-import inspect
 import argparse
+import inspect
+import os
+
 import torch
-from tqdm.auto import tqdm
 import torch.nn as nn
-from torch.utils.data import DataLoader
 from datasets import load_dataset
-import os
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments

+import colossalai
+from colossalai.logging import get_dist_logger
+from colossalai.shardformer.shard.shardconfig import ShardConfig
+from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.utils import get_current_device, print_rank_0
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
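A minimal usage sketch for the refactored ShardModel with an explicit policy, assuming a 2-process launch (e.g. `colossalai run --nproc_per_node 2` as in the README); with custom_policy=None the auto policy is used instead:

``` python
import os

from transformers import BertForMaskedLM

import colossalai
from colossalai.shardformer.policies.bert import BertForMaskedLMPolicy
from colossalai.shardformer.shard.shardconfig import ShardConfig
from colossalai.shardformer.shard.shardmodel import ShardModel

colossalai.launch_from_torch(config={})    # assumes a torchrun/colossalai launch
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
shard_config = ShardConfig(
    rank=int(os.environ['RANK']),
    world_size=int(os.environ['WORLD_SIZE']),
)
# pass the policy class explicitly instead of relying on get_autopolicy
sharded = ShardModel(model, shard_config, custom_policy=BertForMaskedLMPolicy).model
```


 def get_args():
     parser = colossalai.get_default_parser()
     parser.add_argument("--mode", type=str, 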
diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py
index 70593c3b0176..0cdc6ef38fd2 100644
--- a/colossalai/shardformer/test/test.py
+++ b/colossalai/shardformer/test/test.py
@@ -1,41 +1,47 @@
-from transformers import AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
-from transformers import BertForMaskedLM
-import colossalai
-from colossalai.shardformer.shard.shardmodel import ShardModel
-from colossalai.utils import get_current_device, print_rank_0
-from colossalai.logging import get_dist_logger
-from colossalai.shardformer.shard.shardconfig import ShardConfig
-import inspect
 import argparse
+import inspect
+import os
+
 import torch
-from tqdm.auto import tqdm
 import torch.nn as nn
-from torch.utils.data import DataLoader
 from datasets import load_dataset
-import os
+from torch.utils.data import DataLoader
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments

+import colossalai
+from colossalai.logging import get_dist_logger
+from colossalai.shardformer.shard.shardconfig import ShardConfig
+from colossalai.shardformer.shard.shardmodel import ShardModel
+from colossalai.utils import get_current_device, print_rank_0
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

+
 def get_args():
     parser = colossalai.get_default_parser()
     parser.add_argument("--mode", type=str, default='inference')
     return parser.parse_args()

+
 def load_data():
-    datasets=load_dataset('wikitext', 'wikitext-2-raw-v1')
+    datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
     # datasets=load_dataset("yelp_review_full")
-    tokenized_datasets=datasets.map(lambda examples:tokenizer(examples["text"],truncation=True,padding="max_length"),batched=True)
-    tokenized_datasets=tokenized_datasets.remove_columns(["text"])
+    tokenized_datasets = datasets.map(
+        lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True)
+    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
     # tokenized_datasets=tokenized_datasets.rename_column("label","labels")
     tokenized_datasets.set_format("torch")
-
-    train_dataset=tokenized_datasets["train"].select(range(1000))
-    test_dataset=tokenized_datasets["test"].select(range(100))
+
+    train_dataset = tokenized_datasets["train"].select(range(500))
+    test_dataset = tokenized_datasets["test"].select(range(100))

     datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
-    train_dataloader=DataLoader(train_dataset,batch_size=8,shuffle=True, collate_fn=datacollector)
-    eval_dataloader=DataLoader(test_dataset,batch_size=8, collate_fn=datacollector)
-    return train_dataloader,eval_dataloader
+    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
+    eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector)
+    return train_dataloader, eval_dataloader

+
 def inference(model: nn.Module):
     print(model)
@@ -47,15 +53,16 @@ def inference(model: nn.Module):
     outputs = model(**inputs)
     print(outputs)

-def train(model: nn.Module, num_epoch: int=2):
-    train_dataloader, eval_dataloader=load_data()
+
+def train(model: nn.Module, num_epoch: int = 2):
+    train_dataloader, eval_dataloader = load_data()
     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
-    progress_bar = tqdm(range((num_epoch)*len(train_dataloader)))
+    progress_bar = tqdm(range(num_epoch * len(train_dataloader)))
     criterion = nn.CrossEntropyLoss()
     model.to("cuda")
     model.train()
     for epoch in range(num_epoch):
-        progress_bar.set_description(f"epoch: {epoch}")
+        progress_bar.set_description(f"Device {get_current_device()} epoch {epoch}")
         for batch in train_dataloader:
             optimizer.zero_grad()
@@ -65,26 +72,29 @@ def train(model: nn.Module, num_epoch: int = 2):
             loss.backward()
             optimizer.step()
             progress_bar.update(1)
-            print(f"\nRank:{os.environ['RANK']} Epoch:{epoch} Train Loss:{loss:.4f}")
-
-        # for batch in eval_dataloader:
-        #     batch = {k: v.to('cuda') for k, v in batch.items()}
-        #     outputs = model(**batch)
-        #     loss = outputs['loss']
-        #     # loss = criterion(outputs.logits, batch["input_ids"])
-        #     print(f"\nRank {os.environ['RANK']} Epoch:{epoch} Test Loss:{loss:.4f}")
-
+            train_loss = loss
+
+        loss = 0.0
+        for batch in eval_dataloader:
+            batch = {k: v.to('cuda') for k, v in batch.items()}
+            outputs = model(**batch)
+            # loss = outputs.loss
+            loss += outputs.loss.item()
+            # loss = criterion(outputs.logits, batch["input_ids"])
+        test_loss = loss / len(eval_dataloader)
+        print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss: {test_loss:.4f}")

+
 if __name__ == "__main__":
     args = get_args()
     colossalai.launch_from_torch(config=args.config)
     model = BertForMaskedLM.from_pretrained("bert-base-uncased")
     shard_config = ShardConfig(
-        rank = int(str(get_current_device()).split(':')[-1]),
-        world_size= int(os.environ['WORLD_SIZE']),
+        rank=int(str(get_current_device()).split(':')[-1]),
+        world_size=int(os.environ['WORLD_SIZE']),
     )
     shardmodel = ShardModel(model, shard_config)
     if args.mode == "train":
        train(shardmodel.model)
     elif args.mode == "inference":
-        inference(shardmodel.model)
+        inference(shardmodel.model)
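For reference, a launch sketch for the test script above (an assumption inferred from the colossalai.launch_from_torch call, which reads the RANK/WORLD_SIZE/LOCAL_RANK variables that torchrun provides; the --config flag comes from colossalai.get_default_parser()):

    # run the sharded test on a single node with 2 GPUs
    #   torchrun --nproc_per_node 2 colossalai/shardformer/test/test.py \
    #       --config colossalai/shardformer/test/config.py --mode train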
From 8a0e26023fb0559e36a7f4dce7e5ce89dabb1bd9 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Tue, 23 May 2023 17:44:21 +0800
Subject: [PATCH 11/12] add docstring for other files

---
 colossalai/shardformer/policies/autopolicy.py |  25 ++--
 colossalai/shardformer/policies/basepolicy.py | 122 +++++++++++-------
 colossalai/shardformer/shard/slicer.py        | 111 ++++++----------
 3 files changed, 133 insertions(+), 125 deletions(-)

diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py
index 9142e0dae22e..e096c2b13a59 100644
--- a/colossalai/shardformer/policies/autopolicy.py
+++ b/colossalai/shardformer/policies/autopolicy.py
@@ -1,40 +1,47 @@
 import torch.nn as nn

+
 def build_policies():
-    """
+    r"""
     Build the policies for the model
-
+
     Return:
         The dict for the policies
     """
     auto_policy_dict = {}

     from transformers.models.bert.modeling_bert import BertForMaskedLM
+
     from .bert import BertForMaskedLMPolicy
     auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

     from transformers.models.bert.modeling_bert import BertForSequenceClassification
+
     from .bert import BertForSequenceClassificationPolicy
     auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
-
+
     return auto_policy_dict

-def get_autopolicy(model:nn.Module):
-    """
+
+def get_autopolicy(model: nn.Module):
+    r"""
     Return the auto policy for the model

     Args:
-        model: The model to be used
+        model (:class:`nn.Module`): The model to get the auto policy for

     Return:
-        The auto policy for the model
+        :class:`Policy`: The auto policy for the model
     """
     auto_policy_dict = build_policies()
     policy = auto_policy_dict.get(model.__class__, None)
-    if policy is None:
-        raise NotImplementedError(f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}")
+    if policy is None:
+        raise NotImplementedError(
+            f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
+        )
     return policy

+
 # from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
 # model = BertForPreTraining
 # policy = get_autopolicy(model)
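A small lookup sketch for the function above (not part of the patch; it assumes the import layout of this series). Note that get_autopolicy keys the table by `model.__class__` and returns the policy class itself, not an instance:

    from transformers import BertForMaskedLM

    from colossalai.shardformer.policies.autopolicy import get_autopolicy

    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    policy = get_autopolicy(model)    # -> BertForMaskedLMPolicy for this model class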
diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py
index aa27495b2600..a5cc0bc68df6 100644
--- a/colossalai/shardformer/policies/basepolicy.py
+++ b/colossalai/shardformer/policies/basepolicy.py
@@ -1,27 +1,38 @@
 # part of code modified from https://github.com/tunib-ai/parallelformers

+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Tuple, Type
+
 import torch
 import torch.nn as nn
-import colossalai.nn as col_nn
-from typing import Any, Dict, List, Type, Tuple, Callable
 from transformers import AutoConfig
-from dataclasses import dataclass, field
+
+import colossalai.nn as col_nn
+

 @dataclass
 class Argument:
-    attr_dict : Dict[str, Any]
-    param_funcs : List[Callable]
+    r"""
+    The argument class for the policy
+
+    Args:
+        attr_dict (Dict[str, Any]): The dict for the param setting
+        param_funcs (:class:`List[Callable]`): The list for the param functions
+    """
+    attr_dict: Dict[str, Any]
+    param_funcs: List[Callable]
+

 @dataclass
 class Layer:
-    """
+    r"""
     The layer object for the policy

     Args:
-        weight: The weight name of the layer
-        bias: The bias name of the layer
-        replace_layer: The layer to replace the original layer
-        ignore: Whether to ignore this layer if it is not in the model
+        weight (str): The weight suffix of the layer
+        bias (str): The bias suffix of the layer
+        replace_layer (:class:`colossalai.nn`): The layer to replace the original layer
+        ignore (bool): Whether to ignore this layer if it is not in the model
     """
     weight: str = None
     bias: str = None
@@ -31,45 +42,55 @@ class Layer:

 @dataclass
 class Col_Layer(Layer):
-    """
+    r"""
     Class for col shard layer in MegatronLM
+
+    Args:
+        gather_output (bool): Whether to gather the output of the layer
     """
     gather_output: bool = False


 @dataclass
 class Row_Layer(Layer):
-    """
+    r"""
     Class for row shard layer in MegatronLM
     """
     pass


 class Policy():
-    """
+    r"""
     The base class for all the policies
-    For each different model, it should have a different policy class, like BertPolicy for Bert Model
-    or OPTPolicy for OPT model.
+    For each different model, it should have a different policy class, like BertPolicy for Bert Model
+    or OPTPolicy for OPT model.

     AutoPolicy:
-        shardformer already defined some policies for huggingface model, just set custom_policy = None
+        Shardformer already defined some policies for huggingface models, just set ``custom_policy`` = None
        to use the auto policy. In shardformer autopolicy, we define a base policy for one type of model,
-        like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM,
+        like BertPolicy, and for each different Bert model in huggingface, like BertForMaskedLM,
        BertForSequenceClassification, etc., for each different Bert model we define a different policy class
-        and overwrite the method inject_policy
-
+        and overwrite methods like ``inject_policy`` to modify the forward and backward process.

     CustomPolicy:
+        If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite
+        all the methods in the ``Policy`` class. You can refer to any policy we defined, like the ``BertPolicy``
+        class, for an example.
+    """
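To make the CustomPolicy route concrete, here is a hypothetical minimal subclass (illustrative only; the name MyBertPolicy and the single overridden hook are not part of the patch, and a real policy would also implement argument_policy and the remaining hooks):

    import colossalai.nn as col_nn
    from colossalai.shardformer.policies.basepolicy import Col_Layer, Policy


    class MyBertPolicy(Policy):

        @staticmethod
        def attn_in():
            # split the query projection across ranks, column style;
            # ignore=True would skip the layer when a model variant lacks it
            return [
                Col_Layer(weight="attention.self.query.weight",
                          bias="attention.self.query.bias",
                          replace_layer=col_nn.Linear1D_Col),
            ]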
+
     @staticmethod
-    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]:
-        """
-        Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions
+    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
+        r"""
+        Return the dict for the modify policy, the key is the original layer class and the value is the
+        argument for the modified layer

         Args:
-            model_config: The config of transformer model
-            shard_setting: The config of distributed model
-
+            model_config (:class:`transformers.Config`): The config of the transformer model
+            shard_config (:class:`ShardConfig`): The config for sharding the model

         Return:
             Dict for the modify policy,
+        ::
             {
                 origin layer class1 (nn.Module): Argument(
                     attr_dict = {
                         attr1: v1,
                         attr2: v2,
                         ...
                     },
                     param_funcs = [
                         function1,
                         function2,
                         ...
                     ]
                 ),
                 origin layer class2 (nn.Module): Argument(
                     attr_dict = {
                         attr1: v1,
                         attr2: v2,
                         ...
                     },
                     param_funcs = [
                         function1,
                         function2,
                         ...
                     ]
                 ),
                 ...
             }
@@ -100,40 +121,51 @@ def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]
         """
         raise NotImplementedError

-
     @staticmethod
     def inject_policy() -> Tuple[nn.Module, nn.Module]:
-        """
-        Return the dict for the inject model
+        r"""
+        Return the tuple for the inject model

         Return:
-            The injected model, key is the original model and value is the new shardmodel
+            The model pair to be injected: the first element is the original model class and the second is the custom shard model class
+        ::
+            (OriginalModel, CustomModel)
+        in ``CustomModel``, we can overwrite the forward and backward process
         """
         return ()

-
     @staticmethod
     def binding_policy() -> Dict:
-        """
+        r"""
         Return the dict for the binding model
+
+        Return:
+            This method should return the binding relationship for layers that share a weight or bias;
+            the key and value are the suffixes of the weight or bias of the model
+        ::
+            return {
+                "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
+            }
         """
         raise NotImplementedError
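What binding_policy expresses is weight tying: in BERT's MLM head the decoder weight is the word embedding, so after slicing, both suffixes must resolve to the same Parameter again. A toy illustration of that invariant (not part of the patch):

    import torch.nn as nn


    class TinyTiedModel(nn.Module):
        """Stand-in for the embedding/decoder tie in BertForMaskedLM."""

        def __init__(self, vocab: int = 8, hidden: int = 4):
            super().__init__()
            self.embed = nn.Embedding(vocab, hidden)
            self.decoder = nn.Linear(hidden, vocab, bias=False)
            self.decoder.weight = self.embed.weight    # the tie that bind_layer restores


    model = TinyTiedModel()
    assert model.decoder.weight is model.embed.weight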
     @staticmethod
     def attn_in() -> List:
-        """
+        r"""
         Attention qkv layer
+        In this kind of method, we should return a list of ``Layer`` objects, where each object should be
+        ``Layer`` for no slicing, ``Col_Layer`` for col slicing, or ``Row_Layer`` for row slicing. The parameters
+        of a ``Layer`` object are documented in the ``Layer`` class.

         Returns:
-            List[Layer]: List of layer object, each layer is the new
+            List[Layer]: List of layer objects, each describing one layer to be sliced
         """
         raise NotImplementedError

-
     @staticmethod
     def attn_out() -> List:
-        """
+        r"""
         Attention output projection layer

         Returns:
@@ -141,46 +173,40 @@ def attn_out() -> List:
         """
         raise NotImplementedError

-
     @staticmethod
     def mlp_in() -> List:
-        """
+        r"""
         h -> 4h mlp layer

         Returns:
             List[Layer]: List of layer object
         """
         raise NotImplementedError

-
     @staticmethod
     def mlp_out() -> List:
-        """
+        r"""
         4h -> h mlp layer

         Returns:
             List[Layer]: List of layer object
         """
         raise NotImplementedError

-
-
     @staticmethod
-    def embedding()->List:
-        """
+    def embedding() -> List:
+        r"""
         Partially slice the embedding layer
-        vocab_size->vocab_size//gpu_nums

         Return:
             List[Layer]: List of layer object
         """
         raise NotImplementedError

-
-
     @staticmethod
-    def unembedding()->List:
-        """
+    def unembedding() -> List:
+        r"""
         Partially slice the unembedding (LM head) layer
-        vocab_size->vocab_size//gpu_nums

         Return:
             List[Layer]: List of layer object
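A worked example of the sizing rule the embedding hooks rely on (illustrative, using BERT-base numbers): with vocab_size = 30522 and world_size = 2, the ceil division delta = (30522 + 2 - 1) // 2 = 15261 gives the per-rank embedding rows, and the last rank simply holds whatever remainder an uneven split leaves:

    vocab_size, world_size = 30522, 2
    delta = (vocab_size + world_size - 1) // world_size
    shards = [min(delta, vocab_size - rank * delta) for rank in range(world_size)]
    assert shards == [15261, 15261] and sum(shards) == vocab_size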
diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py
index 1849cdc99c72..0ca33dd7a09c 100644
--- a/colossalai/shardformer/shard/slicer.py
+++ b/colossalai/shardformer/shard/slicer.py
@@ -1,40 +1,40 @@
 import os
-from typing import Dict, Tuple
 from dataclasses import dataclass
+from typing import Dict, Tuple

 import torch
 import torch.distributed as dist

-from ..policies.basepolicy import Layer, Col_Layer, Row_Layer
-from .shardconfig import ShardConfig
+from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
+from .shardconfig import ShardConfig

 dim_mapping = {Col_Layer: 1, Row_Layer: 0}

+
 class Slicer():

     def __init__(
-        self,
-        shardconfig: ShardConfig #TODO
+            self,
+            shardconfig: ShardConfig    #TODO
     ) -> None:
         self.shardconfig = shardconfig

-
     def slice_weight_bias(
         self,
         weight: torch.Tensor,
         bias: torch.Tensor,
         policy_layer_cls: Layer,
     ):
-        """
+        r"""
         Slice the weight and bias according to policy layer cls
-        Layer -> do nothing
-        Col_Layer -> slice the weight and bias along dim 1
-        Row_Layer -> slice the weight along dim 0 and do not slice bias
+        ``Layer`` -> do nothing
+        ``Col_Layer`` -> slice the weight and bias along the output dimension
+        ``Row_Layer`` -> slice the weight along the input dimension and do not slice the bias

         Args:
-            weight: The weight of the layer
-            bias: The bias of the layer
-            policy_layer_class: The class represent how to slice the tensor
+            weight (:class:`torch.Tensor`): The weight of the layer
+            bias (:class:`torch.Tensor`): The bias of the layer
+            policy_layer_cls (:class:`Layer`): The class representing how to slice the tensor
         """
         if policy_layer_cls == Layer:
             return weight, bias
@@ -46,42 +46,6 @@ def slice_weight_bias(
         else:
             raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
         return weight, bias
-
-
-    def slice_weight(
-        self,
-        weight: torch.Tensor,
-        policy_layer_cls: Layer,
-    ) -> torch.Tensor:
-        """
-        Slice the weight and bias according to the shardconfig
-
-        Args:
-            weight: The weight of the layer
-            bias: The bias of the layer
-            policy_layer_class: The class represent how to slice the tensor
-        """
-        if weight is not None:
-            dim = dim_mapping[policy_layer_cls]
-            weight = self.slice_tensor(weight, dim, False)
-        return weight
-
-
-    def slice_bias(
-        self,
-        bias: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Slice the bias according to the shardconfig
-
-        Args:
-            bias: The bias of the layer
-        """
-        assert bias is not None, "The bias is None"
-        if bias is not None:
-            bias = self.slice_tensor(bias, 1, True)
-        return bias
-

     def slice_tensor(
         self,
         tensor_in: torch.Tensor,
         dim: int,
         is_bias: bool,
     ) -> torch.Tensor:
-        """
+        r"""
         Slice tensor according to the config
+
+        Args:
+            tensor_in (:class:`torch.Tensor`): The tensor to slice
+            dim (int): The dimension to slice
+            is_bias (bool): Whether the tensor is a bias
         """
         if tensor_in is None:
             return None
@@ -99,69 +68,75 @@ def slice_tensor(
         else:
             return self.slice_1d(tensor_in)

-
     def slice_2d(
         self,
         tensor: torch.Tensor,
         dim: int,
     ) -> torch.Tensor:
         """
-        Slice the 2D tensor
+        Slice the 2D tensor

         Args:
-            tensor: The tensor to slice
+            tensor (:class:`torch.Tensor`): The tensor to slice
+            dim (int): The dimension to slice
         """
-        assert dim in [0,1], f"Only support 2D tensor, but got {dim}D tensor"
+        assert dim in [0, 1], f"dim must be 0 or 1, but got {dim}"
         if dim == 0:
             return self.slice_row(tensor)
         elif dim == 1:
             return self.slice_col(tensor)

-
     def slice_1d(
         self,
         tensor: torch.Tensor,
-        dim: int = None,
     ) -> torch.Tensor:
-        """
-        Slice the 1D tensor
+        r"""
+        Slice the 1D tensor

         Args:
-            tensor: The tensor to slice
+            tensor (:class:`torch.Tensor`): The tensor to slice
+
+        Returns:
+            :class:`torch.Tensor`: The sliced tensor
         """
         delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
         down_idx = self.shardconfig.rank * delta
         up_idx = down_idx + delta
-        return tensor[down_idx:up_idx]
+        return tensor[down_idx:up_idx].contiguous()

     def slice_col(
         self,
         tensor: torch.Tensor,
     ) -> torch.Tensor:
-        """
+        r"""
         Slice the tensor in column

         Args:
-            tensor: The tensor to slice
+            tensor (:class:`torch.Tensor`): The tensor to slice
+
+        Returns:
+            :class:`torch.Tensor`: The sliced tensor
+
         """
         delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
         down_idx = self.shardconfig.rank * delta
         up_idx = down_idx + delta
-        return tensor[down_idx:up_idx,:]
-
+        return tensor[down_idx:up_idx, :].contiguous()

     def slice_row(
         self,
         tensor: torch.Tensor,
     ) -> torch.Tensor:
-        """
+        r"""
         Slice the tensor in row

         Args:
-            tensor: The tensor to slice
+            tensor (:class:`torch.Tensor`): The tensor to slice
+
+        Returns:
+            :class:`torch.Tensor`: The sliced tensor
         """
         delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size
         down_idx = self.shardconfig.rank * delta
         up_idx = down_idx + delta
-        return tensor[:,down_idx:up_idx]
-
\ No newline at end of file
+        return tensor[:, down_idx:up_idx].contiguous()
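To pin down the two 2D modes above (a standalone illustration, not part of the patch): for a PyTorch Linear weight of shape (out_features, in_features), col slicing partitions dim 0 and row slicing partitions dim 1, matching slice_col and slice_row:

    import torch

    weight = torch.arange(24.).reshape(6, 4)    # (out_features, in_features)
    rank, world_size = 0, 2
    delta0 = (weight.shape[0] + world_size - 1) // world_size
    delta1 = (weight.shape[1] + world_size - 1) // world_size
    col_shard = weight[rank * delta0:(rank + 1) * delta0, :].contiguous()    # slice_col
    row_shard = weight[:, rank * delta1:(rank + 1) * delta1].contiguous()    # slice_row
    assert col_shard.shape == (3, 4) and row_shard.shape == (6, 2)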
From 372be82c884c3a561d2c9b0a41c34deb3c48e758 Mon Sep 17 00:00:00 2001
From: FoolPlayer <498107402@qq.com>
Date: Tue, 23 May 2023 18:00:42 +0800
Subject: [PATCH 12/12] pre-commit

---
 colossalai/shardformer/model/modeling_bert.py |  14 ++-
 colossalai/shardformer/policies/bert.py       | 116 +++++++++---------
 colossalai/shardformer/shard/shardconfig.py   |   4 +-
 colossalai/shardformer/test/config.py         |   6 +-
 colossalai/shardformer/utils/utils.py         |  36 +++---
 5 files changed, 88 insertions(+), 88 deletions(-)

diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py
index 724be56dca0e..6741ae866991 100644
--- a/colossalai/shardformer/model/modeling_bert.py
+++ b/colossalai/shardformer/model/modeling_bert.py
@@ -1,12 +1,14 @@
+from typing import Any, Dict, List, Type
+
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
-from typing import Any, Dict, List, Type
-
-
 from transformers import BertForMaskedLM
 from transformers.models.bert.modeling_bert import MaskedLMOutput
+
+
 class BertForMaskedLM_(BertForMaskedLM):
+
     def forward(
         self,
         input_ids=None,
@@ -46,9 +48,9 @@ def forward(

         masked_lm_loss = None
         # if input_ids is not None:
-        #     masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
+        #     masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size)
         if labels is not None:
-            loss_fct = CrossEntropyLoss()  # -100 index = padding token
+            loss_fct = CrossEntropyLoss()    # -100 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

         if not return_dict:
@@ -60,4 +62,4 @@ def forward(
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
\ No newline at end of file
+        )
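A two-line illustration of the "-100 index = padding token" comment above (not part of the patch): CrossEntropyLoss ignores label positions set to -100, which is how DataCollatorForLanguageModeling marks the tokens that were not masked:

    import torch
    from torch.nn import CrossEntropyLoss

    logits = torch.randn(4, 10)                  # 4 token positions, vocab of 10
    labels = torch.tensor([3, -100, -100, 7])    # only positions 0 and 3 contribute
    loss = CrossEntropyLoss()(logits, labels)    # default ignore_index is -100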
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index a8532100c0da..5d91d8ddc766 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -1,53 +1,51 @@
-from typing import Dict, List, Tuple, Type, Any, Callable
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Tuple, Type
+
 import torch.nn as nn
-from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer
+from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
+
 import colossalai.nn as col_nn
-from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead
-from dataclasses import dataclass
+
+from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer


 class BertPolicy(Policy):
+
     @staticmethod
-    def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]:
+    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
         return {
-            BertLayer: Argument(
-                attr_dict = {
-                    # 1. shard hidden size
-                    "attention.self.all_head_size": config.hidden_size // world_size,
-                    "crossattention.self.all_head_size": config.hidden_size // world_size,
-                    # 2. shard number of heads
-                    "attention.self.num_attention_heads": config.num_attention_heads // world_size,
-                    "crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
-
-                },
-                param_funcs = [
-                    BertPolicy.attn_in,
-                    BertPolicy.attn_out,
-                    BertPolicy.mlp_in,
-                    BertPolicy.mlp_out
-                ]
-            ),
-            BertEmbeddings: Argument(
-                attr_dict = {
-                    # 1. shard vocab size
-                    # "word_embeddings.num_embeddings": config.vocab_size // world_size,
-                    # 2. add the size of the sliced embedding layer excluding the last slice
-                    "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size,
-                },
-                param_funcs = [
-                    BertPolicy.embedding,
-                ]
-            ),
-            BertLMPredictionHead: Argument(
-                attr_dict = {
-                    # 1. shard vocab size
-                    # "word_embeddings.num_embeddings": config.vocab_size // world_size,
-                    # 2. add the size of the sliced embedding layer excluding the last slice
-                },
-                param_funcs = [
-                    BertPolicy.unembedding,
-                ]
-            )
+            BertLayer:
+                Argument(
+                    attr_dict={
+                        # 1. shard hidden size
+                        "attention.self.all_head_size": config.hidden_size // world_size,
+                        "crossattention.self.all_head_size": config.hidden_size // world_size,
+                        # 2. shard number of heads
+                        "attention.self.num_attention_heads": config.num_attention_heads // world_size,
+                        "crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
+                    },
+                    param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
+            BertEmbeddings:
+                Argument(
+                    attr_dict={
+                        # 1. shard vocab size
+                        # "word_embeddings.num_embeddings": config.vocab_size // world_size,
+                        # 2. add the size of the sliced embedding layer excluding the last slice
+                        "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
+                    },
+                    param_funcs=[
+                        BertPolicy.embedding,
+                    ]),
+            BertLMPredictionHead:
+                Argument(
+                    attr_dict={
+                        # 1. shard vocab size
+                        # "word_embeddings.num_embeddings": config.vocab_size // world_size,
+                        # 2. add the size of the sliced embedding layer excluding the last slice
+                    },
+                    param_funcs=[
+                        BertPolicy.unembedding,
+                    ])
         }

     @staticmethod
@@ -92,9 +90,8 @@ def attn_in() -> List:
                 replace_layer=col_nn.Linear1D_Col,
                 ignore=True,
             ),
-        ]
-
+        ]

     @staticmethod
     def attn_out() -> List:
         return [
@@ -110,17 +107,17 @@ def attn_out() -> List:
                 ignore=True,
             ),
         ]
-
+
     @staticmethod
     def mlp_in() -> List:
         return [
-            Col_Layer(
+            Col_Layer(
                 weight="intermediate.dense.weight",
                 bias="intermediate.dense.bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
         ]
-
+
     @staticmethod
     def mlp_out() -> List:
         return [
@@ -133,13 +130,11 @@ def mlp_out() -> List:

     @staticmethod
     def embedding() -> List:
-        return [
-            Col_Layer(
-                weight="word_embeddings.weight",
-                replace_layer=col_nn.VocabParallelEmbedding1D,
-            )
-        ]
-
+        return [Col_Layer(
+            weight="word_embeddings.weight",
+            replace_layer=col_nn.VocabParallelEmbedding1D,
+        )]
+
     @staticmethod
     def unembedding() -> List:
         return [
@@ -151,16 +146,21 @@ def unembedding() -> List:
             )
         ]

+
+from transformers import BertForMaskedLM
+
+from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
+
+
 class BertForMaskedLMPolicy(BertPolicy):
+
     @staticmethod
     def inject_policy() -> Tuple[nn.Module, nn.Module]:
         return (BertForMaskedLM, BertForMaskedLM_)
-
-
+

 class BertForSequenceClassificationPolicy(BertPolicy):
+
     @staticmethod
     def inject_policy() -> Dict:
         return {}
@@ -168,4 +168,4 @@ def inject_policy() -> Dict:

 # model = BertForMaskedLM.from_pretrained("bert-base-uncased")
 # _ = BertForMaskedLMPolicy(model)
-# print(isinstance(model,list(_.inject_policy().keys())[0]))
\ No newline at end of file
+# print(isinstance(model,list(_.inject_policy().keys())[0]))

diff --git a/colossalai/shardformer/shard/shardconfig.py b/colossalai/shardformer/shard/shardconfig.py
index be265ff0c8c1..c6a2513a6eff 100644
--- a/colossalai/shardformer/shard/shardconfig.py
+++ b/colossalai/shardformer/shard/shardconfig.py
@@ -10,9 +10,9 @@ class ShardConfig:
     fp16: bool = True
     num_gpus: int = 2
     world_size: int = 2
-    backend="nccl"
+    backend: str = "nccl"
     verbose: str = 'simple'
     seed: int = None
     require_grad: bool = False
     master_addr: str = "127.0.0.1"
-    master_port: int = 29500
\ No newline at end of file
+    master_port: int = 29500

diff --git a/colossalai/shardformer/test/config.py b/colossalai/shardformer/test/config.py
index 295529429237..2b80d8b3ca12 100644
--- a/colossalai/shardformer/test/config.py
+++ b/colossalai/shardformer/test/config.py
@@ -1,5 +1 @@
-parallel = dict(
-    data=1,
-    pipeline=1,
-    tensor=dict(size=2, mode='1d')
-)
\ No newline at end of file
+parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))
diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py
index 5eba87f6fe09..eb84edd88404 100644
--- a/colossalai/shardformer/utils/utils.py
+++ b/colossalai/shardformer/utils/utils.py
@@ -1,10 +1,10 @@
 def hasattr_(obj, attr: str):
-    """
+    r"""
     Check whether the object has the given multi-level attribute

     Args:
-        obj: The object to check
-        attr: The multi level attr to check
+        obj (object): The object to check
+        attr (str): The multi-level attribute to check
     """
     attrs = attr.split('.')
     for a in attrs:
@@ -14,15 +14,16 @@ def hasattr_(obj, attr: str):
             return False
     return True

-def setattr_(obj, attr: str, value, ignore: bool=False):
-    """
+
+def setattr_(obj, attr: str, value, ignore: bool = False):
+    r"""
     Set the object's multi-level attribute to the given value; if ``ignore`` is True, skip silently when the attribute doesn't exist

     Args:
-        obj: The object to set
-        attr: The multi level attr to set
-        value: The value to set
-        ignore: Whether to ignore when the attr doesn't exist
+        obj (object): The object to set
+        attr (str): The multi-level attribute to set
+        value (Any): The value to set
+        ignore (bool): Whether to ignore when the attribute doesn't exist
     """

     attrs = attr.split('.')
@@ -31,18 +32,19 @@ def setattr_(obj, attr: str, value, ignore: bool=False):
             obj = getattr(obj, a)
         except AttributeError:
             if ignore:
-                return
+                return
             raise AttributeError(f"Object {obj} has no attribute {attr}")
     setattr(obj, attrs[-1], value)

-def getattr_(obj, attr: str, ignore: bool=None):
-    """
+
+def getattr_(obj, attr: str, ignore: bool = None):
+    r"""
     Get the object's multi-level attribute
-
+
     Args:
-        obj: The object to set
-        attr: The multi level attr to set
-        ignore: Whether to ignore when the attr doesn't exist
+        obj (object): The object to get from
+        attr (str): The multi-level attribute to get
+        ignore (bool): Whether to ignore when the attribute doesn't exist
     """

     attrs = attr.split('.')
@@ -53,4 +55,4 @@ def getattr_(obj, attr: str, ignore: bool=None):
         if ignore:
             return None
         raise AttributeError(f"Object {obj} has no attribute {attr}")
-    return obj
\ No newline at end of file
+    return obj
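A quick demonstration of the three dotted-path helpers (not part of the patch; it assumes the module path used throughout this series):

    import torch.nn as nn

    from colossalai.shardformer.utils.utils import getattr_, hasattr_, setattr_

    block = nn.Sequential()
    block.add_module("attention", nn.Linear(4, 4))

    assert hasattr_(block, "attention.weight")
    w = getattr_(block, "attention.weight")
    setattr_(block, "attention.out_features", w.shape[0])    # walks "attention", sets "out_features"
    assert getattr_(block, "attention.missing", ignore=True) is None    # ignore=True -> None instead of raising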