From 604a2131cb9495f7966e023d21858153a9440fc8 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Mon, 22 May 2023 15:02:17 +0800 Subject: [PATCH 01/49] [shardformer] init shardformer code structure (#3731) * init shardformer code structure * add implement of sharder (inject and replace) * add implement of replace layer to colossal layer * separate different layer policy, add some notion * implement 1d and 2d slicer, can tell col or row * fix bug when slicing and inject model * fix some bug; add inference test example --- colossalai/shardformer/__init__.py | 0 colossalai/shardformer/model/__init__.py | 0 colossalai/shardformer/model/modeling_bert.py | 63 +++++ colossalai/shardformer/policies/__init__.py | 0 colossalai/shardformer/policies/autopolicy.py | 41 +++ colossalai/shardformer/policies/basepolicy.py | 182 ++++++++++++++ colossalai/shardformer/policies/bert.py | 168 +++++++++++++ colossalai/shardformer/shard/__init__.py | 0 colossalai/shardformer/shard/shardconfig.py | 18 ++ colossalai/shardformer/shard/sharder.py | 238 ++++++++++++++++++ colossalai/shardformer/shard/shardmodel.py | 58 +++++ colossalai/shardformer/shard/slicer.py | 167 ++++++++++++ colossalai/shardformer/test/config.py | 5 + colossalai/shardformer/test/test.py | 37 +++ colossalai/shardformer/utils/__init__.py | 0 colossalai/shardformer/utils/utils.py | 56 +++++ 16 files changed, 1033 insertions(+) create mode 100644 colossalai/shardformer/__init__.py create mode 100644 colossalai/shardformer/model/__init__.py create mode 100644 colossalai/shardformer/model/modeling_bert.py create mode 100644 colossalai/shardformer/policies/__init__.py create mode 100644 colossalai/shardformer/policies/autopolicy.py create mode 100644 colossalai/shardformer/policies/basepolicy.py create mode 100644 colossalai/shardformer/policies/bert.py create mode 100644 colossalai/shardformer/shard/__init__.py create mode 100644 colossalai/shardformer/shard/shardconfig.py create mode 100644 colossalai/shardformer/shard/sharder.py create mode 100644 colossalai/shardformer/shard/shardmodel.py create mode 100644 colossalai/shardformer/shard/slicer.py create mode 100644 colossalai/shardformer/test/config.py create mode 100644 colossalai/shardformer/test/test.py create mode 100644 colossalai/shardformer/utils/__init__.py create mode 100644 colossalai/shardformer/utils/utils.py diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/model/__init__.py b/colossalai/shardformer/model/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py new file mode 100644 index 000000000000..87ed8ac308a5 --- /dev/null +++ b/colossalai/shardformer/model/modeling_bert.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss +from typing import Any, Dict, List, Type + + +from transformers import BertForMaskedLM +from transformers.models.bert.modeling_bert import MaskedLMOutput +class BertForMaskedLM_(BertForMaskedLM): + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + print("[Inject OK] Injected forward method") + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + + # if input_ids is not None: + # masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size) + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/colossalai/shardformer/policies/__init__.py b/colossalai/shardformer/policies/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py new file mode 100644 index 000000000000..9142e0dae22e --- /dev/null +++ b/colossalai/shardformer/policies/autopolicy.py @@ -0,0 +1,41 @@ +import torch.nn as nn + +def build_policies(): + """ + Build the policies for the model + + Return: + The dict for the policies + """ + auto_policy_dict = {} + + from transformers.models.bert.modeling_bert import BertForMaskedLM + from .bert import BertForMaskedLMPolicy + auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy + + from transformers.models.bert.modeling_bert import BertForSequenceClassification + from .bert import BertForSequenceClassificationPolicy + auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy + + return auto_policy_dict + +def get_autopolicy(model:nn.Module): + """ + Return the auto policy for the model + + Args: + model: The model to be used + + Return: + The auto policy for the model + """ + auto_policy_dict = build_policies() + policy = auto_policy_dict.get(model.__class__, None) + if policy is None: + raise NotImplementedError(f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}") + return policy + +# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining +# model = BertForPreTraining +# policy = get_autopolicy(model) +# print(policy) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py new file mode 100644 index 000000000000..d444aeb53bf8 --- /dev/null +++ b/colossalai/shardformer/policies/basepolicy.py @@ -0,0 +1,182 @@ +# part of code modified from https://github.com/tunib-ai/parallelformers + +import torch +import torch.nn as nn +import colossalai.nn as col_nn +from typing import Any, Dict, List, Type, Tuple, Callable +from transformers import AutoConfig +from dataclasses import dataclass, field + +@dataclass +class Argument: + attr_dict : Dict[str, Any] + param_funcs : List[Callable] + binding_layers : List[nn.Module] = field(default_factory=list) + +@dataclass +class Layer: + """ + The layer object for the policy + + Args: + weight: The weight name of the layer + bias: The bias name of the layer + replace_layer: The layer to replace the original layer + ignore: Whether to ignore this layer if it is not in the model + """ + weight: str = None + bias: str = None + replace_layer: Any = None + ignore: bool = False + + +@dataclass +class Col_Layer(Layer): + """ + Class for col shard layer in MegatronLM + """ + gather_output: bool = False + + +@dataclass +class Row_Layer(Layer): + """ + Class for col shard layer in MegatronLM + """ + pass + + +class Policy(): + """ + The base class for all the policies + For each different model, it should have a different policy class, like BertPolicy for Bert Model + or OPTPolicy for OPT model. + AutoPolicy: + shardformer already defined some policies for huggingface model, just set custom_policy = None + to use the auto policy. In shardformer autopolicy, we define a base policy for one type model, + like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM, + BertForSequenceClassification, etc., for each different Bert model we difine different policy class + and overwrite the method inject_policy + + CustomPolicy: + """ + @staticmethod + def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]: + """ + Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions + + Args: + model_config: The config of transformer model + shard_setting: The config of distributed model + + Return: + Dict for the modify policy, + { + origin layer class1 (nn.Module): Argument( + attr_dict = { + argument1: value1, + argument2: value2, + ... + }, + param_funcs = [ + staticmethod1, + staticmethod2, + ... + ] + ), + origin layer class2 (nn.Module): Argument( + attr_dict = { + argument1: value1, + argument2: value2, + ... + }, + param_funcs = [ + staticmethod1, + staticmethod2, + ... + ] + ), + ... + } + + """ + raise NotImplementedError + + + @staticmethod + def inject_policy() -> Tuple[nn.Module, nn.Module]: + """ + Return the dict for the inject model + + Return: + The injected model, key is the original model and value is the new shardmodel + """ + return () + + + @staticmethod + def attn_in() -> List: + """ + Attention qkv layer + + Returns: + List[Layer]: List of layer object, each layer is the new + """ + return NotImplementedError + + + @staticmethod + def attn_out() -> List: + """ + Attention output projection layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + + @staticmethod + def mlp_in() -> List: + """ + h -> 4h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + + @staticmethod + def mlp_out() -> List: + """ + 4h -> h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + + @staticmethod + def embedding()->List: + """ + Partially slice the embedding layer + vocab_size->vocab_size//gpu_nums + + Return: + List[Layer]: List of layer object + """ + return NotImplementedError + + + @staticmethod + def unembedding()->List: + """ + Partially slice the embedding layer + vocab_size->vocab_size//gpu_nums + + Return: + List[Layer]: List of layer object + """ + return NotImplementedError diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py new file mode 100644 index 000000000000..24b95e827347 --- /dev/null +++ b/colossalai/shardformer/policies/bert.py @@ -0,0 +1,168 @@ +from typing import Dict, List, Tuple, Type, Any, Callable +import torch.nn as nn +from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer +import colossalai.nn as col_nn +from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead +from dataclasses import dataclass + + +class BertPolicy(Policy): + @staticmethod + def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]: + return { + BertLayer: Argument( + attr_dict = { + # 1. shard hidden size + "attention.self.all_head_size": config.hidden_size // world_size, + "crossattention.self.all_head_size": config.hidden_size // world_size, + # 2. shard number of heads + "attention.self.num_attention_heads": config.num_attention_heads // world_size, + "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, + + }, + param_funcs = [ + BertPolicy.attn_in, + BertPolicy.attn_out, + BertPolicy.mlp_in, + BertPolicy.mlp_out + ] + ), + BertEmbeddings: Argument( + attr_dict = { + # 1. shard vocab size + # "word_embeddings.num_embeddings": config.vocab_size // world_size, + # 2. add the size of the sliced embedding layer excluding the last slice + "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size, + }, + param_funcs = [ + BertPolicy.embedding, + ], + binding_layers = [ + BertLMPredictionHead, + ] + ), + BertLMPredictionHead: Argument( + attr_dict = { + # 1. shard vocab size + # "word_embeddings.num_embeddings": config.vocab_size // world_size, + # 2. add the size of the sliced embedding layer excluding the last slice + }, + param_funcs = [ + BertPolicy.unembedding, + ] + ) + } + + @staticmethod + def attn_in() -> List: + return [ + Col_Layer( + weight="attention.self.query.weight", + bias="attention.self.query.bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + weight="attention.self.key.weight", + bias="attention.self.key.bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + weight="attention.self.value.weight", + bias="attention.self.value.bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + weight="crossattention.self.query.weight", + bias="crossattention.self.query.bias", + replace_layer=col_nn.Linear1D_Col, + ignore=True, + ), + Col_Layer( + weight="crossattention.self.key.weight", + bias="crossattention.self.key.bias", + replace_layer=col_nn.Linear1D_Col, + ignore=True, + ), + Col_Layer( + weight="crossattention.self.value.weight", + bias="crossattention.self.value.bias", + replace_layer=col_nn.Linear1D_Col, + ignore=True, + ), + + ] + + @staticmethod + def attn_out() -> List: + return [ + Row_Layer( + weight="attention.output.dense.weight", + bias="attention.output.dense.bias", + replace_layer=col_nn.Linear1D_Row, + ), + Row_Layer( + weight="crossattention.output.dense.weight", + bias="crossattention.output.dense.bias", + replace_layer=col_nn.Linear1D_Row, + ignore=True, + ), + ] + + @staticmethod + def mlp_in() -> List: + return [ + Col_Layer( + weight="intermediate.dense.weight", + bias="intermediate.dense.bias", + replace_layer=col_nn.Linear1D_Col, + ), + ] + + @staticmethod + def mlp_out() -> List: + return [ + Row_Layer( + weight="output.dense.weight", + bias="output.dense.bias", + replace_layer=col_nn.Linear1D_Row, + ), + ] + + @staticmethod + def embedding() -> List: + return [ + Col_Layer( + weight="word_embeddings.weight", + replace_layer=col_nn.VocabParallelEmbedding1D, + ) + ] + + @staticmethod + def unembedding() -> List: + return [ + Col_Layer( + weight="decoder.weight", + bias="decoder.bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ) + ] + +from transformers import BertForMaskedLM +from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ +class BertForMaskedLMPolicy(BertPolicy): + @staticmethod + def inject_policy() -> Tuple[nn.Module, nn.Module]: + return (BertForMaskedLM, BertForMaskedLM_) + + + +class BertForSequenceClassificationPolicy(BertPolicy): + @staticmethod + def inject_policy() -> Dict: + return {} + + +# model = BertForMaskedLM.from_pretrained("bert-base-uncased") +# _ = BertForMaskedLMPolicy(model) +# print(isinstance(model,list(_.inject_policy().keys())[0])) \ No newline at end of file diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/shard/shardconfig.py b/colossalai/shardformer/shard/shardconfig.py new file mode 100644 index 000000000000..be265ff0c8c1 --- /dev/null +++ b/colossalai/shardformer/shard/shardconfig.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + + +@dataclass +class ShardConfig: + """ + The config for sharding the huggingface model for test + """ + rank: int + fp16: bool = True + num_gpus: int = 2 + world_size: int = 2 + backend="nccl" + verbose: str = 'simple' + seed: int = None + require_grad: bool = False + master_addr: str = "127.0.0.1" + master_port: int = 29500 \ No newline at end of file diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py new file mode 100644 index 000000000000..ef785cfee9da --- /dev/null +++ b/colossalai/shardformer/shard/sharder.py @@ -0,0 +1,238 @@ +import torch +import torch.nn as nn +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union, Callable +from .shardconfig import ShardConfig +from dataclasses import dataclass +from ..policies.basepolicy import Policy, Layer +from ..policies.autopolicy import get_autopolicy +from .slicer import Slicer +from ..utils.utils import hasattr_, setattr_, getattr_ +import colossalai.nn as col_nn +from colossalai.logging import get_dist_logger +import os + + +logger = get_dist_logger() + +class ModelSharder(object): + """ + Shard the original huggingface model according to the policy + + Args: + policy: The policy to shard the model + model: The model to shard + dist_setting: The setting of distributed model + """ + def __init__( + self, + model: nn.Module, + policy: Policy, + shard_config: ShardConfig = None, # TODO + ) -> None: + self.model = model + self.policy = get_autopolicy(self.model) if policy is None else policy + self.slicer = Slicer(shard_config) + self.shard_config = shard_config + self.model_config = self.model.config + self.binding_map = {} + + + def shard(self) -> None: + self.inject_model(self.model) + self.replace_layer(self.model) + + + def inject_model( + self, + model: nn.Module, + ) -> None: + """ + Replace the model to policy defined model + Mainly modify the forward and backward to fit distributed model + + e.g. + BertForMaskedLM.forward -> BertForMaskedLM_.forward + """ + inject_policy = self.policy.inject_policy() + + org_model_cls = inject_policy[0] + shard_model_cls = inject_policy[1] + + if model.__class__ == org_model_cls: + for key in shard_model_cls.__dict__.keys(): + if hasattr(model.__class__, key): + setattr( + model.__class__, + key, + getattr(shard_model_cls,key), + ) + else: + raise NotImplementedError(f"{model.__class__} is not implemented so far") + + + def replace_layer( + self, + model: nn.Module, + ) -> None: + """ + Replace the layer according to the policy, and replace the layer one by one + + Args: + layer: The layer to shard + """ + argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size) + for argument_policy in argument_policies.items(): + origin_layer_cls = argument_policy[0] + attr_dict = argument_policy[1].attr_dict + param_funcs = argument_policy[1].param_funcs + binding_layers = argument_policy[1].binding_layers + # if binding_layer is not None: + # self.binding_map[origin_layer_cls] = binding_layer + self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers) + + + def reverse_replace_layer( + self, + layer: nn.Module, + origin_cls: nn.Module, + attr_dict: Dict[str, Any], + param_funcs: List[Callable], + binding_layers: List[nn.Module] + ) -> None: + """ + Reverse the replace layer operation + + Args: + layer: The object of layer to shard + origin_cls: The origin layer class + attr_dict: The attribute dict to modify + policy_cls: The policy class + """ + for name, child in layer.named_children(): + if child.__class__ == origin_cls: + # replac_layer = child + for k, v in attr_dict.items(): + setattr_(child, k, v, ignore=True) + # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__) + # setattr_(layer, name, self.shard_one_layer(child, policy_cls)) + self.shard_one_layer(child, param_funcs, binding_layers) + continue + + self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers) + return layer + + + def shard_one_layer( + self, + org_layer: nn.Module, + param_funcs: List[Callable], + binding_layers: List[nn.Module] + ) -> None: + """ + Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict + + Args: + org_layer: The origin layer object to shard + param_funcs: The function list to get shard information in policy class + + """ + # print(org_layer) + for func in param_funcs: + policy_layers = func() + for policy_layer in policy_layers: + weight = None + bias = None + weight_attr = policy_layer.weight + bias_attr = policy_layer.bias + replace_layer_cls = policy_layer.replace_layer + ignore = policy_layer.ignore + if policy_layer.__class__.__name__ == "Col_Layer": + gather_output = policy_layer.gather_output + print(gather_output) + + if weight_attr is not None: + if hasattr_(org_layer, weight_attr): + weight = getattr_(org_layer, weight_attr) + elif not ignore: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") + + if bias_attr is not None: + if hasattr_(org_layer, bias_attr): + bias = getattr_(org_layer, bias_attr) + elif not ignore: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") + + # dont have the attribute in policy, and ignore is true + if weight is None and bias is None and ignore: + continue + + # set the sliced weight and bias to the new nn_col layer + assert weight is not None or bias is not None + layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr) + + # slice weight and bias + weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__) + print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None) + # save the binding information + for binding_layer in binding_layers: + self.binding_map[binding_layer] = dict(weight=weight, bias=bias) + + # create new object to replace the origin layer + if replace_layer_cls is not None: + # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}") + if isinstance(getattr_(org_layer, layer_attr), nn.Linear): + if replace_layer_cls.__name__ == "Linear1D_Row": + replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=False if bias is None else True) + elif replace_layer_cls.__name__ == "Linear1D_Col": + replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=False if bias is None else True, gather_output=gather_output) + setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) + self.set_param(replace_layer, weight, bias) + elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): + replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) + setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) + self.set_param(replace_layer, weight, bias) + else: + raise NotImplementedError(f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") + # do not replace the layer object, just replace the weight and bias + else: + self.set_param(org_layer, layer_attr, weight, bias) + + + def set_param( + self, + layer: Any, + layer_attr: str = "", + weight: torch.Tensor = None, + bias: torch.Tensor = None + ) -> None: + """ + Reset the weight and bias of the layer object + + Args: + layer: The layer object + layer_attr: The attribute name of the layer + weight: The weight of the layer + bias: The bias of the layer + """ + assert weight is not None or bias is not None + if weight is not None: + setattr_(layer, "weight" if layer_attr == "" else layer_attr+".weight", nn.Parameter(weight)) + self.set_layer_size(layer, layer_attr, weight.shape) + if bias is not None: + setattr_(layer, "bias" if layer_attr == "" else layer_attr+".bias", nn.Parameter(bias)) + + + def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None: + """ + Set the layer attribute + + Args: + layer: The layer object + layer_attr: The attribute name of the layer + size: Torch.size + """ + # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features + attrs = ["out_features", "in_features"] + for i, attr in enumerate(attrs): + if hasattr_(layer, f"{layer_attr}.{attr}"): + setattr_(layer, f"{layer_attr}.{attr}", size[i]) diff --git a/colossalai/shardformer/shard/shardmodel.py b/colossalai/shardformer/shard/shardmodel.py new file mode 100644 index 000000000000..54d7b5ba02d9 --- /dev/null +++ b/colossalai/shardformer/shard/shardmodel.py @@ -0,0 +1,58 @@ +import os +import torch +import torch.nn as nn +import transformers +import torch.distributed as dist +from dataclasses import dataclass +from contextlib import suppress + +from colossalai.tensor.d_tensor.layout import Layout +from ..policies.basepolicy import Policy +from .sharder import ModelSharder +from .shardconfig import ShardConfig + + +class ShardModel(object): + """ + The class for sharding the huggingface model, self.model is the sharded model + Just creat a new ShardModel object to shard huggingface model + + Args: + model: the origin huggingface model + dist_config: the config for distribute information + custom_policy: the custom policy for sharding + """ + def __init__( + self, + model: nn.Module, + shard_config: ShardConfig = None, # TODO + custom_policy: Policy = None, + ) -> None: + self.model = model + self.shard_config = shard_config + self.policy = custom_policy + # self.layout=, # TODO + + sharder=ModelSharder( + model=self.model, + policy=self.policy, + shard_config=self.shard_config, + ) + sharder.shard() + + + def set_environ(self) -> None: + os.environ["TOKENIZERS_PARALLELISM"] = "true" + os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU" + os.environ["MASTER_ADDR"] = str(self.dist_config.master_addr) + os.environ["MASTER_PORT"] = str(self.dist_config.master_port) + os.environ["WORLD_SIZE"] = str(self.dist_config.num_gpus) + os.environ["RANK"] = str(self.dist_config.rank) + os.environ["LOCAL_RANK"] = str(self.dist_config.rank) + if not dist.is_initialized(): + dist.init_process_group(backend=self.dist_config.backend) + + torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0"))) + + def back_to_org() -> None: + pass \ No newline at end of file diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py new file mode 100644 index 000000000000..1849cdc99c72 --- /dev/null +++ b/colossalai/shardformer/shard/slicer.py @@ -0,0 +1,167 @@ +import os +from typing import Dict, Tuple +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from ..policies.basepolicy import Layer, Col_Layer, Row_Layer +from .shardconfig import ShardConfig + + +dim_mapping = {Col_Layer: 1, Row_Layer: 0} + +class Slicer(): + + def __init__( + self, + shardconfig: ShardConfig #TODO + ) -> None: + self.shardconfig = shardconfig + + + def slice_weight_bias( + self, + weight: torch.Tensor, + bias: torch.Tensor, + policy_layer_cls: Layer, + ): + """ + Slice the weight and bias according to policy layer cls + Layer -> do nothing + Col_Layer -> slice the weight and bias along dim 1 + Row_Layer -> slice the weight along dim 0 and do not slice bias + + Args: + weight: The weight of the layer + bias: The bias of the layer + policy_layer_class: The class represent how to slice the tensor + """ + if policy_layer_cls == Layer: + return weight, bias + elif policy_layer_cls == Col_Layer: + weight = self.slice_tensor(weight, 1, False) + bias = self.slice_tensor(bias, 0, True) + elif policy_layer_cls == Row_Layer: + weight = self.slice_tensor(weight, 0, False) + else: + raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported") + return weight, bias + + + def slice_weight( + self, + weight: torch.Tensor, + policy_layer_cls: Layer, + ) -> torch.Tensor: + """ + Slice the weight and bias according to the shardconfig + + Args: + weight: The weight of the layer + bias: The bias of the layer + policy_layer_class: The class represent how to slice the tensor + """ + if weight is not None: + dim = dim_mapping[policy_layer_cls] + weight = self.slice_tensor(weight, dim, False) + return weight + + + def slice_bias( + self, + bias: torch.Tensor, + ) -> torch.Tensor: + """ + Slice the bias according to the shardconfig + + Args: + bias: The bias of the layer + """ + assert bias is not None, "The bias is None" + if bias is not None: + bias = self.slice_tensor(bias, 1, True) + return bias + + + def slice_tensor( + self, + tensor_in: torch.Tensor, + dim: int, + is_bias: bool, + ) -> torch.Tensor: + """ + Slice tensor according to the config + """ + if tensor_in is None: + return None + if not is_bias: + return self.slice_2d(tensor_in, dim) + else: + return self.slice_1d(tensor_in) + + + def slice_2d( + self, + tensor: torch.Tensor, + dim: int, + ) -> torch.Tensor: + """ + Slice the 2D tensor + + Args: + tensor: The tensor to slice + """ + assert dim in [0,1], f"Only support 2D tensor, but got {dim}D tensor" + if dim == 0: + return self.slice_row(tensor) + elif dim == 1: + return self.slice_col(tensor) + + + def slice_1d( + self, + tensor: torch.Tensor, + dim: int = None, + ) -> torch.Tensor: + """ + Slice the 1D tensor + + Args: + tensor: The tensor to slice + """ + delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size + down_idx = self.shardconfig.rank * delta + up_idx = down_idx + delta + return tensor[down_idx:up_idx] + + def slice_col( + self, + tensor: torch.Tensor, + ) -> torch.Tensor: + """ + Slice the tensor in column + + Args: + tensor: The tensor to slice + """ + delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size + down_idx = self.shardconfig.rank * delta + up_idx = down_idx + delta + return tensor[down_idx:up_idx,:] + + + def slice_row( + self, + tensor: torch.Tensor, + ) -> torch.Tensor: + """ + Slice the tensor in column + + Args: + tensor: The tensor to slice + """ + delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size + down_idx = self.shardconfig.rank * delta + up_idx = down_idx + delta + return tensor[:,down_idx:up_idx] + \ No newline at end of file diff --git a/colossalai/shardformer/test/config.py b/colossalai/shardformer/test/config.py new file mode 100644 index 000000000000..295529429237 --- /dev/null +++ b/colossalai/shardformer/test/config.py @@ -0,0 +1,5 @@ +parallel = dict( + data=1, + pipeline=1, + tensor=dict(size=2, mode='1d') +) \ No newline at end of file diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py new file mode 100644 index 000000000000..c2a9053ca2f6 --- /dev/null +++ b/colossalai/shardformer/test/test.py @@ -0,0 +1,37 @@ +from transformers import AutoTokenizer +from transformers import BertForMaskedLM +import colossalai +from colossalai.shardformer.shard.shardmodel import ShardModel +from colossalai.utils import get_current_device, print_rank_0 +from colossalai.logging import get_dist_logger +from colossalai.shardformer.shard.shardconfig import ShardConfig +import inspect +import argparse +import torch.nn as nn +import os + +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + +def get_args(): + parser = colossalai.get_default_parser() + return parser.parse_args() + +def inference(model: nn.Module): + # print(model) + token = "Hello, my dog is cute" + inputs = tokenizer(token, return_tensors="pt") + inputs.to("cuda") + model.to("cuda") + outputs = model(**inputs) + print(outputs) + +if __name__ == "__main__": + args = get_args() + colossalai.launch_from_torch(config=args.config) + model = BertForMaskedLM.from_pretrained("bert-base-uncased") + shard_config = ShardConfig( + rank = int(str(get_current_device()).split(':')[-1]), + world_size= int(os.environ['WORLD_SIZE']), + ) + shardmodel = ShardModel(model, shard_config) + inference(shardmodel.model) diff --git a/colossalai/shardformer/utils/__init__.py b/colossalai/shardformer/utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py new file mode 100644 index 000000000000..5eba87f6fe09 --- /dev/null +++ b/colossalai/shardformer/utils/utils.py @@ -0,0 +1,56 @@ +def hasattr_(obj, attr: str): + """ + Check whether the object has the multi sublevel attr + + Args: + obj: The object to check + attr: The multi level attr to check + """ + attrs = attr.split('.') + for a in attrs: + try: + obj = getattr(obj, a) + except AttributeError: + return False + return True + +def setattr_(obj, attr: str, value, ignore: bool=False): + """ + Set the object's multi sublevel attr to value, if ignore, ignore when it doesn't exist + + Args: + obj: The object to set + attr: The multi level attr to set + value: The value to set + ignore: Whether to ignore when the attr doesn't exist + """ + + attrs = attr.split('.') + for a in attrs[:-1]: + try: + obj = getattr(obj, a) + except AttributeError: + if ignore: + return + raise AttributeError(f"Object {obj} has no attribute {attr}") + setattr(obj, attrs[-1], value) + +def getattr_(obj, attr: str, ignore: bool=None): + """ + Get the object's multi sublevel attr + + Args: + obj: The object to set + attr: The multi level attr to set + ignore: Whether to ignore when the attr doesn't exist + """ + + attrs = attr.split('.') + for a in attrs: + try: + obj = getattr(obj, a) + except AttributeError: + if ignore: + return None + raise AttributeError(f"Object {obj} has no attribute {attr}") + return obj \ No newline at end of file From ffacf0fcbade77300ce52568e20a7720a89946a8 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 24 May 2023 10:26:46 +0800 Subject: [PATCH 02/49] [shardformer]: Feature/shardformer, add some docstring and readme (#3816) * init shardformer code structure * add implement of sharder (inject and replace) * add implement of replace layer to colossal layer * separate different layer policy, add some notion * implement 1d and 2d slicer, can tell col or row * fix bug when slicing and inject model * fix some bug; add inference test example * add share weight and train example * add train * add docstring and readme * add docstring for other files * pre-commit --- colossalai/nn/layer/parallel_1d/_operation.py | 2 + colossalai/nn/layer/parallel_1d/layers.py | 9 +- colossalai/shardformer/README.md | 177 +++++++++++++++++ colossalai/shardformer/model/modeling_bert.py | 16 +- colossalai/shardformer/policies/autopolicy.py | 25 ++- colossalai/shardformer/policies/basepolicy.py | 128 +++++++----- colossalai/shardformer/policies/bert.py | 125 ++++++------ colossalai/shardformer/shard/shardconfig.py | 4 +- colossalai/shardformer/shard/sharder.py | 187 +++++++++--------- colossalai/shardformer/shard/shardmodel.py | 36 ++-- colossalai/shardformer/shard/slicer.py | 113 +++++------ colossalai/shardformer/test/config.py | 6 +- colossalai/shardformer/test/test.py | 87 ++++++-- colossalai/shardformer/utils/utils.py | 36 ++-- 14 files changed, 612 insertions(+), 339 deletions(-) create mode 100644 colossalai/shardformer/README.md diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py index 394334558275..c5e33fd497cd 100644 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist + from colossalai.core import global_context as gpc try: @@ -72,6 +73,7 @@ def backward(ctx, grad_output): total_input = input grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() # Convert the tensor shapes to 2D for execution compatibility grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 406173a18c60..0ee3b4fcb502 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -469,7 +469,8 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) + # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size) + self.out_features_per_partition = out_features # Parameters. # Initialize weight. @@ -612,7 +613,8 @@ def __init__(self, raise ValueError('cannot skip bias addition if bias is None') # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size) + self.input_size_per_partition = in_features # Parameters. # Initialize weight. @@ -884,7 +886,8 @@ def __init__(self, tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.num_embeddings_per_partition = num_embeddings self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md new file mode 100644 index 000000000000..a47e280f2be4 --- /dev/null +++ b/colossalai/shardformer/README.md @@ -0,0 +1,177 @@ +## ShardFormer + +### Intro +Make the model in huggingface.co can be paralleled and can be used with colossalai according to custom policy. + +### Quick start +1. Usage +- Use +``` python +from colossalai.shardformer.shard.shardmodel import ShardModel +from transformers import BertForMaskedLM + +# create huggingface model as normal +model = BertForMaskedLM.from_pretrained("bert-base-uncased") + +# make the huggingface model paralleled to ShardModel +# auto policy: +shardmodel = ShardModel(model).model + +# custom policy: +from xxx import +shardmodel = ShardModel(model, ).model + + +# do angthing as normal +... +``` +- Policy + +If you wanna parallel the model in custom way, just overwrite the policy class for the huggingface model. + +You should do: + +1. Inherit Policy class +2. Overwrite argument_policy method + - In this method you need to list which layers class you wanna modify and the attributes and parameters in those layers. +3. Overwrite inject_policy method [Optional] + - If you need to modify the forward or backward progress. +4. Overwrite or add the param recording functions + - These function use suffix to record the path of weight or bias for the layer. +5. Overwrite binding + +More details can be found in shardformer/policies/basepolicy.py +``` python +from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument + +CustomPolicy(Policy): + @staticmethod + def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]: + """ + Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions + + Args: + model_config: The config of transformer model + shard_setting: The config of distributed model + + Return: + Dict for the modify policy, + { + origin layer class1 (nn.Module): Argument( + attr_dict = { + argument1: value1, + argument2: value2, + ... + }, + param_funcs = [ + staticmethod1, + staticmethod2, + ... + ] + ), + origin layer class2 (nn.Module): Argument( + attr_dict = { + argument1: value1, + argument2: value2, + ... + }, + param_funcs = [ + staticmethod1, + staticmethod2, + ... + ] + ), + ... + } + + """ + raise NotImplementedError + + @staticmethod + def inject_policy() -> Tuple[nn.Module, nn.Module]: + """ + Return the dict for the inject model + + Return: + The injected model, key is the original model and value is the new shardmodel + """ + return () + + @staticmethod + def binding_policy() -> Dict: + """ + Return the dict for the binding model + """ + return NotImplementedError + + @staticmethod + def attn_in() -> List: + """ + Attention qkv layer + + Returns: + List[Layer]: List of layer object, each layer is the new + """ + return NotImplementedError + + @staticmethod + def attn_out() -> List: + """ + Attention output projection layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def mlp_in() -> List: + """ + h -> 4h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def mlp_out() -> List: + """ + 4h -> h mlp layer + + Returns: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def embedding() -> List: + """ + Partially slice the embedding layer + vocab_size->vocab_size//gpu_nums + + Return: + List[Layer]: List of layer object + """ + return NotImplementedError + + @staticmethod + def unembedding() -> List: + """ + Partially slice the embedding layer + vocab_size->vocab_size//gpu_nums + + Return: + List[Layer]: List of layer object + """ + return NotImplementedError + +``` + +2. Simple example +``` shell +# inference +colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference +# train +colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train +``` diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py index 87ed8ac308a5..6741ae866991 100644 --- a/colossalai/shardformer/model/modeling_bert.py +++ b/colossalai/shardformer/model/modeling_bert.py @@ -1,12 +1,14 @@ +from typing import Any, Dict, List, Type + import torch import torch.nn as nn from torch.nn import CrossEntropyLoss -from typing import Any, Dict, List, Type - - from transformers import BertForMaskedLM from transformers.models.bert.modeling_bert import MaskedLMOutput + + class BertForMaskedLM_(BertForMaskedLM): + def forward( self, input_ids=None, @@ -23,7 +25,7 @@ def forward( return_dict=None, **kwargs, ): - print("[Inject OK] Injected forward method") + # print("[Inject OK] Injected forward method") return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.bert( @@ -46,9 +48,9 @@ def forward( masked_lm_loss = None # if input_ids is not None: - # masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size) + # masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size) if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -60,4 +62,4 @@ def forward( logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - ) \ No newline at end of file + ) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 9142e0dae22e..e096c2b13a59 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -1,40 +1,47 @@ import torch.nn as nn + def build_policies(): - """ + r""" Build the policies for the model - + Return: The dict for the policies """ auto_policy_dict = {} from transformers.models.bert.modeling_bert import BertForMaskedLM + from .bert import BertForMaskedLMPolicy auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy from transformers.models.bert.modeling_bert import BertForSequenceClassification + from .bert import BertForSequenceClassificationPolicy auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy - + return auto_policy_dict -def get_autopolicy(model:nn.Module): - """ + +def get_autopolicy(model: nn.Module): + r""" Return the auto policy for the model Args: - model: The model to be used + model (:class:`nn.Module`): The model to get the auto policy Return: - The auto policy for the model + :class:`Policy`: The auto policy for the model """ auto_policy_dict = build_policies() policy = auto_policy_dict.get(model.__class__, None) - if policy is None: - raise NotImplementedError(f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}") + if policy is None: + raise NotImplementedError( + f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}" + ) return policy + # from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining # model = BertForPreTraining # policy = get_autopolicy(model) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index d444aeb53bf8..a5cc0bc68df6 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -1,28 +1,38 @@ # part of code modified from https://github.com/tunib-ai/parallelformers +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Tuple, Type + import torch import torch.nn as nn -import colossalai.nn as col_nn -from typing import Any, Dict, List, Type, Tuple, Callable from transformers import AutoConfig -from dataclasses import dataclass, field + +import colossalai.nn as col_nn + @dataclass class Argument: - attr_dict : Dict[str, Any] - param_funcs : List[Callable] - binding_layers : List[nn.Module] = field(default_factory=list) + r""" + The argument class for the policy + + Args: + attr_dict (Dict[str, Any]): The dict for the param setting + param_funcs (:class:`List[Callable]`): The list for the param functions + """ + attr_dict: Dict[str, Any] + param_funcs: List[Callable] + @dataclass class Layer: - """ + r""" The layer object for the policy Args: - weight: The weight name of the layer - bias: The bias name of the layer - replace_layer: The layer to replace the original layer - ignore: Whether to ignore this layer if it is not in the model + weight (str): The weight suffix of the layer + bias (str): The bias suffix of the layer + replace_layer (:class:`colosalai.nn`): The layer to replace the original layer + ignore (bool): Whether to ignore this layer if it is not in the model """ weight: str = None bias: str = None @@ -32,45 +42,55 @@ class Layer: @dataclass class Col_Layer(Layer): - """ + r""" Class for col shard layer in MegatronLM + + Args: + gather_output (bool): Whether to gather the output of the layer """ gather_output: bool = False @dataclass class Row_Layer(Layer): - """ + r""" Class for col shard layer in MegatronLM """ pass class Policy(): - """ + r""" The base class for all the policies - For each different model, it should have a different policy class, like BertPolicy for Bert Model - or OPTPolicy for OPT model. + For each different model, it should have a different policy class, like BertPolicy for Bert Model + or OPTPolicy for OPT model. AutoPolicy: - shardformer already defined some policies for huggingface model, just set custom_policy = None + Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None to use the auto policy. In shardformer autopolicy, we define a base policy for one type model, - like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM, + like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM, BertForSequenceClassification, etc., for each different Bert model we difine different policy class - and overwrite the method inject_policy - + and overwrite the method like ``inject_policy`` to modify the forward and backward process. + CustomPolicy: + If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite + all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy`` + class for the example. + """ + @staticmethod - def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]: - """ - Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions + def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: + r""" + Return the dict for the modify policy, the key is the original layer class and the value is the + argument for the modify layer Args: - model_config: The config of transformer model - shard_setting: The config of distributed model - + model_config (:class:`tansformer.Config`): The config of transformer model + shard_config (:class:`ShardConfig`): The config for sharding model + Return: Dict for the modify policy, + :: { origin layer class1 (nn.Module): Argument( attr_dict = { @@ -101,33 +121,51 @@ def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument] """ raise NotImplementedError - @staticmethod def inject_policy() -> Tuple[nn.Module, nn.Module]: - """ - Return the dict for the inject model + r""" + Return the dict for the inject model Return: The injected model, key is the original model and value is the new shardmodel + :: + (OrignModel, CustomModel) + in `CustomModel`, we can overwrite the forward and backward process """ return () - @staticmethod - def attn_in() -> List: + def binding_policy() -> Dict: + r""" + Return the dict for the binding model + + Return: + This method should return the binding relationship for some layers share the weight or bias, + the key and value is the suffix of the weight or bias of the model + :: + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", + } """ + return NotImplementedError + + @staticmethod + def attn_in() -> List: + r""" Attention qkv layer + In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be + ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters + in ``Layer`` object can refer to the ``Layer`` class. Returns: - List[Layer]: List of layer object, each layer is the new + List[Layer]: List of layer object, each layer is the new """ return NotImplementedError - @staticmethod def attn_out() -> List: - """ + r""" Attention output projection layer Returns: @@ -135,46 +173,40 @@ def attn_out() -> List: """ return NotImplementedError - @staticmethod def mlp_in() -> List: - """ + r""" h -> 4h mlp layer Returns: List[Layer]: List of layer object """ return NotImplementedError - @staticmethod def mlp_out() -> List: - """ + r""" 4h -> h mlp layer Returns: List[Layer]: List of layer object """ return NotImplementedError - - + @staticmethod - def embedding()->List: - """ + def embedding() -> List: + r""" Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums Return: List[Layer]: List of layer object """ return NotImplementedError - - + @staticmethod - def unembedding()->List: - """ + def unembedding() -> List: + r""" Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums Return: List[Layer]: List of layer object diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 24b95e827347..5d91d8ddc766 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,56 +1,57 @@ -from typing import Dict, List, Tuple, Type, Any, Callable +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Tuple, Type + import torch.nn as nn -from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer +from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead + import colossalai.nn as col_nn -from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead -from dataclasses import dataclass + +from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer class BertPolicy(Policy): + @staticmethod - def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]: + def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: return { - BertLayer: Argument( - attr_dict = { - # 1. shard hidden size - "attention.self.all_head_size": config.hidden_size // world_size, - "crossattention.self.all_head_size": config.hidden_size // world_size, - # 2. shard number of heads - "attention.self.num_attention_heads": config.num_attention_heads // world_size, - "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, - - }, - param_funcs = [ - BertPolicy.attn_in, - BertPolicy.attn_out, - BertPolicy.mlp_in, - BertPolicy.mlp_out - ] - ), - BertEmbeddings: Argument( - attr_dict = { - # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice - "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size, - }, - param_funcs = [ - BertPolicy.embedding, - ], - binding_layers = [ - BertLMPredictionHead, - ] - ), - BertLMPredictionHead: Argument( - attr_dict = { - # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice - }, - param_funcs = [ - BertPolicy.unembedding, - ] - ) + BertLayer: + Argument( + attr_dict={ + # 1. shard hidden size + "attention.self.all_head_size": config.hidden_size // world_size, + "crossattention.self.all_head_size": config.hidden_size // world_size, + # 2. shard number of heads + "attention.self.num_attention_heads": config.num_attention_heads // world_size, + "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, + }, + param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]), + BertEmbeddings: + Argument( + attr_dict={ + # 1. shard vocab size + # "word_embeddings.num_embeddings": config.vocab_size // world_size, + # 2. add the size of the sliced embedding layer excluding the last slice + "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size, + }, + param_funcs=[ + BertPolicy.embedding, + ]), + BertLMPredictionHead: + Argument( + attr_dict={ + # 1. shard vocab size + # "word_embeddings.num_embeddings": config.vocab_size // world_size, + # 2. add the size of the sliced embedding layer excluding the last slice + }, + param_funcs=[ + BertPolicy.unembedding, + ]) + } + + @staticmethod + def binding_policy() -> Dict: + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", } @staticmethod @@ -89,9 +90,8 @@ def attn_in() -> List: replace_layer=col_nn.Linear1D_Col, ignore=True, ), - ] - + @staticmethod def attn_out() -> List: return [ @@ -107,17 +107,17 @@ def attn_out() -> List: ignore=True, ), ] - + @staticmethod def mlp_in() -> List: return [ - Col_Layer( + Col_Layer( weight="intermediate.dense.weight", bias="intermediate.dense.bias", replace_layer=col_nn.Linear1D_Col, ), ] - + @staticmethod def mlp_out() -> List: return [ @@ -130,13 +130,11 @@ def mlp_out() -> List: @staticmethod def embedding() -> List: - return [ - Col_Layer( - weight="word_embeddings.weight", - replace_layer=col_nn.VocabParallelEmbedding1D, - ) - ] - + return [Col_Layer( + weight="word_embeddings.weight", + replace_layer=col_nn.VocabParallelEmbedding1D, + )] + @staticmethod def unembedding() -> List: return [ @@ -148,16 +146,21 @@ def unembedding() -> List: ) ] + from transformers import BertForMaskedLM + from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ + + class BertForMaskedLMPolicy(BertPolicy): + @staticmethod def inject_policy() -> Tuple[nn.Module, nn.Module]: return (BertForMaskedLM, BertForMaskedLM_) - - + class BertForSequenceClassificationPolicy(BertPolicy): + @staticmethod def inject_policy() -> Dict: return {} @@ -165,4 +168,4 @@ def inject_policy() -> Dict: # model = BertForMaskedLM.from_pretrained("bert-base-uncased") # _ = BertForMaskedLMPolicy(model) -# print(isinstance(model,list(_.inject_policy().keys())[0])) \ No newline at end of file +# print(isinstance(model,list(_.inject_policy().keys())[0])) diff --git a/colossalai/shardformer/shard/shardconfig.py b/colossalai/shardformer/shard/shardconfig.py index be265ff0c8c1..c6a2513a6eff 100644 --- a/colossalai/shardformer/shard/shardconfig.py +++ b/colossalai/shardformer/shard/shardconfig.py @@ -10,9 +10,9 @@ class ShardConfig: fp16: bool = True num_gpus: int = 2 world_size: int = 2 - backend="nccl" + backend = "nccl" verbose: str = 'simple' seed: int = None require_grad: bool = False master_addr: str = "127.0.0.1" - master_port: int = 29500 \ No newline at end of file + master_port: int = 29500 diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index ef785cfee9da..2f6bb4265a11 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,56 +1,59 @@ +import os +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union + import torch import torch.nn as nn -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union, Callable -from .shardconfig import ShardConfig -from dataclasses import dataclass -from ..policies.basepolicy import Policy, Layer -from ..policies.autopolicy import get_autopolicy -from .slicer import Slicer -from ..utils.utils import hasattr_, setattr_, getattr_ + import colossalai.nn as col_nn from colossalai.logging import get_dist_logger -import os +from ..policies.autopolicy import get_autopolicy +from ..policies.basepolicy import Layer, Policy +from ..utils.utils import getattr_, hasattr_, setattr_ +from .shardconfig import ShardConfig +from .slicer import Slicer logger = get_dist_logger() + class ModelSharder(object): - """ + r""" Shard the original huggingface model according to the policy Args: - policy: The policy to shard the model - model: The model to shard - dist_setting: The setting of distributed model + policy (:class:`Policy`): The policy to shard the model + model (:class:`torch.Module`): The model to shard + shard_config: The setting of distributed model """ + def __init__( self, model: nn.Module, policy: Policy, - shard_config: ShardConfig = None, # TODO - ) -> None: + shard_config: ShardConfig = None, # TODO + ) -> None: self.model = model self.policy = get_autopolicy(self.model) if policy is None else policy self.slicer = Slicer(shard_config) self.shard_config = shard_config self.model_config = self.model.config - self.binding_map = {} - def shard(self) -> None: self.inject_model(self.model) self.replace_layer(self.model) - - + self.bind_layer(self.model) + def inject_model( - self, - model: nn.Module, - ) -> None: - """ + self, + model: nn.Module, + ) -> None: + r""" Replace the model to policy defined model Mainly modify the forward and backward to fit distributed model - + e.g. + :: BertForMaskedLM.forward -> BertForMaskedLM_.forward """ inject_policy = self.policy.inject_policy() @@ -64,49 +67,43 @@ def inject_model( setattr( model.__class__, key, - getattr(shard_model_cls,key), + getattr(shard_model_cls, key), ) else: raise NotImplementedError(f"{model.__class__} is not implemented so far") - def replace_layer( - self, - model: nn.Module, - ) -> None: - """ + self, + model: nn.Module, + ) -> None: + r""" Replace the layer according to the policy, and replace the layer one by one Args: - layer: The layer to shard + model (:class:`torch.nn.Module`): The layer to shard """ argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size) for argument_policy in argument_policies.items(): origin_layer_cls = argument_policy[0] attr_dict = argument_policy[1].attr_dict param_funcs = argument_policy[1].param_funcs - binding_layers = argument_policy[1].binding_layers - # if binding_layer is not None: - # self.binding_map[origin_layer_cls] = binding_layer - self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs, binding_layers) - + self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) def reverse_replace_layer( - self, - layer: nn.Module, - origin_cls: nn.Module, - attr_dict: Dict[str, Any], - param_funcs: List[Callable], - binding_layers: List[nn.Module] - ) -> None: - """ + self, + layer: nn.Module, + origin_cls: nn.Module, + attr_dict: Dict[str, Any], + param_funcs: List[Callable], + ) -> None: + r""" Reverse the replace layer operation Args: - layer: The object of layer to shard - origin_cls: The origin layer class - attr_dict: The attribute dict to modify - policy_cls: The policy class + layer (:class:`torch.nn.Module`): The object of layer to shard + origin_cls (:class:`transformers.model`): The origin layer class + attr_dict (Dict): The attribute dict to modify + policy_cls (:class:`Policy`): The policy class """ for name, child in layer.named_children(): if child.__class__ == origin_cls: @@ -115,25 +112,23 @@ def reverse_replace_layer( setattr_(child, k, v, ignore=True) # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__) # setattr_(layer, name, self.shard_one_layer(child, policy_cls)) - self.shard_one_layer(child, param_funcs, binding_layers) + self.shard_one_layer(child, param_funcs) continue - self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs, binding_layers) + self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs) return layer - def shard_one_layer( - self, - org_layer: nn.Module, - param_funcs: List[Callable], - binding_layers: List[nn.Module] - ) -> None: - """ + self, + org_layer: nn.Module, + param_funcs: List[Callable], + ) -> None: + r""" Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict Args: - org_layer: The origin layer object to shard - param_funcs: The function list to get shard information in policy class + org_layer (:class:`torch.nn.Module`): The origin layer object to shard + param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class """ # print(org_layer) @@ -148,7 +143,7 @@ def shard_one_layer( ignore = policy_layer.ignore if policy_layer.__class__.__name__ == "Col_Layer": gather_output = policy_layer.gather_output - print(gather_output) + # print(gather_output) if weight_attr is not None: if hasattr_(org_layer, weight_attr): @@ -172,67 +167,81 @@ def shard_one_layer( # slice weight and bias weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__) - print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None) - # save the binding information - for binding_layer in binding_layers: - self.binding_map[binding_layer] = dict(weight=weight, bias=bias) + # print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None) # create new object to replace the origin layer if replace_layer_cls is not None: # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}") if isinstance(getattr_(org_layer, layer_attr), nn.Linear): if replace_layer_cls.__name__ == "Linear1D_Row": - replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], bias=False if bias is None else True) + replace_layer = replace_layer_cls(weight.shape[1], + weight.shape[0], + bias=False if bias is None else True) elif replace_layer_cls.__name__ == "Linear1D_Col": - replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=False if bias is None else True, gather_output=gather_output) + replace_layer = replace_layer_cls(weight.shape[0], + weight.shape[1], + bias=False if bias is None else True, + gather_output=gather_output) setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) self.set_param(replace_layer, weight, bias) - elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): - replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) + elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): + replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], + getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) self.set_param(replace_layer, weight, bias) else: - raise NotImplementedError(f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") + raise NotImplementedError( + f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") # do not replace the layer object, just replace the weight and bias else: self.set_param(org_layer, layer_attr, weight, bias) - - def set_param( - self, - layer: Any, - layer_attr: str = "", - weight: torch.Tensor = None, - bias: torch.Tensor = None - ) -> None: - """ + def set_param(self, + layer: Any, + weight: torch.Tensor = None, + bias: torch.Tensor = None, + layer_attr: str = "") -> None: + r""" Reset the weight and bias of the layer object Args: - layer: The layer object - layer_attr: The attribute name of the layer - weight: The weight of the layer - bias: The bias of the layer + layer (:class:`torch.nn.Module`): The layer object + layer_attr (str): The attribute name of the layer + weight (:class:`torch.Tensor`): The weight of the layer + bias (:class:`torch.Tensor`): The bias of the layer """ assert weight is not None or bias is not None if weight is not None: - setattr_(layer, "weight" if layer_attr == "" else layer_attr+".weight", nn.Parameter(weight)) + setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous())) self.set_layer_size(layer, layer_attr, weight.shape) if bias is not None: - setattr_(layer, "bias" if layer_attr == "" else layer_attr+".bias", nn.Parameter(bias)) - + setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous())) def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None: - """ + r""" Set the layer attribute Args: - layer: The layer object - layer_attr: The attribute name of the layer - size: Torch.size + layer (:class:`torch.nn.Module`): The layer object + layer_attr (str): The attribute name of the layer + size (:class:`torch.Size`): The size of the tensor """ # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features attrs = ["out_features", "in_features"] for i, attr in enumerate(attrs): if hasattr_(layer, f"{layer_attr}.{attr}"): - setattr_(layer, f"{layer_attr}.{attr}", size[i]) + setattr_(layer, f"{layer_attr}.{attr}", size[i]) + + def bind_layer(self, model: nn.Module) -> None: + r""" + Bind the layer according to the binding policy + + Args: + model (:class:`torch.nn.Module`): The shard model + """ + binding_map = self.policy.binding_policy() + for k, v in binding_map.items(): + param = getattr_(model, k) + param = nn.Parameter(param) + setattr_(model, k, param) + setattr_(model, v, param) diff --git a/colossalai/shardformer/shard/shardmodel.py b/colossalai/shardformer/shard/shardmodel.py index 54d7b5ba02d9..7e7d1576afd6 100644 --- a/colossalai/shardformer/shard/shardmodel.py +++ b/colossalai/shardformer/shard/shardmodel.py @@ -1,46 +1,48 @@ import os +from contextlib import suppress +from dataclasses import dataclass + import torch +import torch.distributed as dist import torch.nn as nn import transformers -import torch.distributed as dist -from dataclasses import dataclass -from contextlib import suppress from colossalai.tensor.d_tensor.layout import Layout + from ..policies.basepolicy import Policy -from .sharder import ModelSharder from .shardconfig import ShardConfig +from .sharder import ModelSharder class ShardModel(object): - """ - The class for sharding the huggingface model, self.model is the sharded model + r""" + The class for sharding the huggingface model, ''self.model'' is the sharded model Just creat a new ShardModel object to shard huggingface model Args: - model: the origin huggingface model - dist_config: the config for distribute information - custom_policy: the custom policy for sharding + model (:class:`torch.nn.Model`): the origin huggingface model + dist_config (:class:`ShardConfig`): the config for distribute information + custom_policy (:class:`Policy`): the custom policy for sharding """ + def __init__( - self, - model: nn.Module, - shard_config: ShardConfig = None, # TODO - custom_policy: Policy = None, - ) -> None: + self, + model: nn.Module, + shard_config: ShardConfig = None, # TODO + custom_policy: Policy = None, + ) -> None: self.model = model self.shard_config = shard_config self.policy = custom_policy # self.layout=, # TODO - sharder=ModelSharder( + sharder = ModelSharder( model=self.model, policy=self.policy, shard_config=self.shard_config, ) sharder.shard() - def set_environ(self) -> None: os.environ["TOKENIZERS_PARALLELISM"] = "true" os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU" @@ -55,4 +57,4 @@ def set_environ(self) -> None: torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0"))) def back_to_org() -> None: - pass \ No newline at end of file + pass diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 1849cdc99c72..096f5db95f49 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -1,40 +1,40 @@ import os -from typing import Dict, Tuple from dataclasses import dataclass +from typing import Dict, Tuple import torch import torch.distributed as dist -from ..policies.basepolicy import Layer, Col_Layer, Row_Layer -from .shardconfig import ShardConfig +from ..policies.basepolicy import Col_Layer, Layer, Row_Layer +from .shardconfig import ShardConfig dim_mapping = {Col_Layer: 1, Row_Layer: 0} + class Slicer(): def __init__( - self, - shardconfig: ShardConfig #TODO + self, + shardconfig: ShardConfig #TODO ) -> None: self.shardconfig = shardconfig - def slice_weight_bias( self, weight: torch.Tensor, bias: torch.Tensor, policy_layer_cls: Layer, ): - """ + r""" Slice the weight and bias according to policy layer cls - Layer -> do nothing - Col_Layer -> slice the weight and bias along dim 1 - Row_Layer -> slice the weight along dim 0 and do not slice bias + ``Layer`` -> do nothing + ``Col_Layer`` -> slice the weight and bias along dim 1 + ``Row_Layer`` -> slice the weight along dim 0 and do not slice bias Args: - weight: The weight of the layer - bias: The bias of the layer - policy_layer_class: The class represent how to slice the tensor + weight (:class:`torch.nn.Module`): The weight of the layer + bias: (:class:`torch.nn.Module`): The bias of the layer + policy_layer_class (:class:`Policy`): The class represent how to slice the tensor """ if policy_layer_cls == Layer: return weight, bias @@ -46,42 +46,6 @@ def slice_weight_bias( else: raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported") return weight, bias - - - def slice_weight( - self, - weight: torch.Tensor, - policy_layer_cls: Layer, - ) -> torch.Tensor: - """ - Slice the weight and bias according to the shardconfig - - Args: - weight: The weight of the layer - bias: The bias of the layer - policy_layer_class: The class represent how to slice the tensor - """ - if weight is not None: - dim = dim_mapping[policy_layer_cls] - weight = self.slice_tensor(weight, dim, False) - return weight - - - def slice_bias( - self, - bias: torch.Tensor, - ) -> torch.Tensor: - """ - Slice the bias according to the shardconfig - - Args: - bias: The bias of the layer - """ - assert bias is not None, "The bias is None" - if bias is not None: - bias = self.slice_tensor(bias, 1, True) - return bias - def slice_tensor( self, @@ -89,8 +53,13 @@ def slice_tensor( dim: int, is_bias: bool, ) -> torch.Tensor: - """ + r""" Slice tensor according to the config + + Args: + tensor_in (:class:`torch.Tensor`): The tensor to slice + dim (int): The dimension to slice + is_bias (bool): Whether the tensor is bias """ if tensor_in is None: return None @@ -99,69 +68,75 @@ def slice_tensor( else: return self.slice_1d(tensor_in) - def slice_2d( self, tensor: torch.Tensor, dim: int, ) -> torch.Tensor: - """ - Slice the 2D tensor + r""" + Slice the 2D tensor Args: - tensor: The tensor to slice + tensor (:class:`torch.Tensor`): The tensor to slice + dim (int): The dimension to slice """ - assert dim in [0,1], f"Only support 2D tensor, but got {dim}D tensor" + assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor" if dim == 0: return self.slice_row(tensor) elif dim == 1: return self.slice_col(tensor) - def slice_1d( self, tensor: torch.Tensor, - dim: int = None, ) -> torch.Tensor: - """ - Slice the 1D tensor + r""" + Slice the 1D tensor Args: - tensor: The tensor to slice + tensor (:class:`torch.Tensor`): The tensor to slice + + Returns: + :class:`torch.Tensor`: The sliced tensor """ delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size down_idx = self.shardconfig.rank * delta up_idx = down_idx + delta - return tensor[down_idx:up_idx] + return tensor[down_idx:up_idx].contiguous() def slice_col( self, tensor: torch.Tensor, ) -> torch.Tensor: - """ + r""" Slice the tensor in column Args: - tensor: The tensor to slice + tensor (:class:`torch.Tensor`): The tensor to slice + + Returns: + :class:`torch.Tensor`: The sliced tensor + """ delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size down_idx = self.shardconfig.rank * delta up_idx = down_idx + delta - return tensor[down_idx:up_idx,:] - + return tensor[down_idx:up_idx, :].contiguous() def slice_row( self, tensor: torch.Tensor, ) -> torch.Tensor: - """ + r""" Slice the tensor in column Args: - tensor: The tensor to slice + tensor (:class:`torch.Tensor`): The tensor to slice + + Returns: + :class:`torch.Tensor`: The sliced tensor """ delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size down_idx = self.shardconfig.rank * delta up_idx = down_idx + delta - return tensor[:,down_idx:up_idx] - \ No newline at end of file + return tensor[:, down_idx:up_idx].contiguous() diff --git a/colossalai/shardformer/test/config.py b/colossalai/shardformer/test/config.py index 295529429237..2b80d8b3ca12 100644 --- a/colossalai/shardformer/test/config.py +++ b/colossalai/shardformer/test/config.py @@ -1,5 +1 @@ -parallel = dict( - data=1, - pipeline=1, - tensor=dict(size=2, mode='1d') -) \ No newline at end of file +parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')) diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py index c2a9053ca2f6..0cdc6ef38fd2 100644 --- a/colossalai/shardformer/test/test.py +++ b/colossalai/shardformer/test/test.py @@ -1,23 +1,51 @@ -from transformers import AutoTokenizer -from transformers import BertForMaskedLM +import argparse +import inspect +import os + +import torch +import torch.nn as nn +from datasets import load_dataset +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments + import colossalai -from colossalai.shardformer.shard.shardmodel import ShardModel -from colossalai.utils import get_current_device, print_rank_0 from colossalai.logging import get_dist_logger from colossalai.shardformer.shard.shardconfig import ShardConfig -import inspect -import argparse -import torch.nn as nn -import os +from colossalai.shardformer.shard.shardmodel import ShardModel +from colossalai.utils import get_current_device, print_rank_0 +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + def get_args(): parser = colossalai.get_default_parser() + parser.add_argument("--mode", type=str, default='inference') return parser.parse_args() + +def load_data(): + datasets = load_dataset('wikitext', 'wikitext-2-raw-v1') + # datasets=load_dataset("yelp_review_full") + tokenized_datasets = datasets.map( + lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True) + tokenized_datasets = tokenized_datasets.remove_columns(["text"]) + # tokenized_datasets=tokenized_datasets.rename_column("label","labels") + tokenized_datasets.set_format("torch") + + train_dataset = tokenized_datasets["train"].select(range(500)) + test_dataset = tokenized_datasets["test"].select(range(100)) + + datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt") + train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector) + eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector) + return train_dataloader, eval_dataloader + + def inference(model: nn.Module): - # print(model) + print(model) + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") token = "Hello, my dog is cute" inputs = tokenizer(token, return_tensors="pt") inputs.to("cuda") @@ -25,13 +53,48 @@ def inference(model: nn.Module): outputs = model(**inputs) print(outputs) + +def train(model: nn.Module, num_epoch: int = 2): + train_dataloader, eval_dataloader = load_data() + optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) + progress_bar = tqdm(range((num_epoch) * len(train_dataloader))) + criterion = nn.CrossEntropyLoss() + model.to("cuda") + model.train() + for epoch in range(num_epoch): + progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}") + + for batch in train_dataloader: + optimizer.zero_grad() + batch = {k: v.to('cuda') for k, v in batch.items()} + outputs = model(**batch) + loss = outputs.loss + loss.backward() + optimizer.step() + progress_bar.update(1) + train_loss = loss + + loss = 0.0 + for batch in eval_dataloader: + batch = {k: v.to('cuda') for k, v in batch.items()} + outputs = model(**batch) + # loss = outputs.loss + loss += outputs.loss.item() + # loss = criterion(outputs.logits, batch["input_ids"]) + test_loss = loss / len(eval_dataloader) + print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}") + + if __name__ == "__main__": args = get_args() colossalai.launch_from_torch(config=args.config) model = BertForMaskedLM.from_pretrained("bert-base-uncased") shard_config = ShardConfig( - rank = int(str(get_current_device()).split(':')[-1]), - world_size= int(os.environ['WORLD_SIZE']), + rank=int(str(get_current_device()).split(':')[-1]), + world_size=int(os.environ['WORLD_SIZE']), ) shardmodel = ShardModel(model, shard_config) - inference(shardmodel.model) + if args.mode == "train": + train(shardmodel.model) + elif args.mode == "inference": + inference(shardmodel.model) diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py index 5eba87f6fe09..eb84edd88404 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/utils/utils.py @@ -1,10 +1,10 @@ def hasattr_(obj, attr: str): - """ + r""" Check whether the object has the multi sublevel attr Args: - obj: The object to check - attr: The multi level attr to check + obj (object): The object to check + attr (str): The multi level attr to check """ attrs = attr.split('.') for a in attrs: @@ -14,15 +14,16 @@ def hasattr_(obj, attr: str): return False return True -def setattr_(obj, attr: str, value, ignore: bool=False): - """ + +def setattr_(obj, attr: str, value, ignore: bool = False): + r""" Set the object's multi sublevel attr to value, if ignore, ignore when it doesn't exist Args: - obj: The object to set - attr: The multi level attr to set - value: The value to set - ignore: Whether to ignore when the attr doesn't exist + obj (object): The object to set + attr (str): The multi level attr to set + value (Any): The value to set + ignore (bool): Whether to ignore when the attr doesn't exist """ attrs = attr.split('.') @@ -31,18 +32,19 @@ def setattr_(obj, attr: str, value, ignore: bool=False): obj = getattr(obj, a) except AttributeError: if ignore: - return + return raise AttributeError(f"Object {obj} has no attribute {attr}") setattr(obj, attrs[-1], value) -def getattr_(obj, attr: str, ignore: bool=None): - """ + +def getattr_(obj, attr: str, ignore: bool = None): + r""" Get the object's multi sublevel attr - + Args: - obj: The object to set - attr: The multi level attr to set - ignore: Whether to ignore when the attr doesn't exist + obj (object): The object to set + attr (str): The multi level attr to set + ignore (bool): Whether to ignore when the attr doesn't exist """ attrs = attr.split('.') @@ -53,4 +55,4 @@ def getattr_(obj, attr: str, ignore: bool=None): if ignore: return None raise AttributeError(f"Object {obj} has no attribute {attr}") - return obj \ No newline at end of file + return obj From 69d3daace09fce93885c7b282f29b1e3c4d4918a Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 24 May 2023 11:51:48 +0800 Subject: [PATCH 03/49] [shardformer] updated readme (#3827) --- colossalai/shardformer/README.md | 53 ++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index a47e280f2be4..f76cbac8d7b8 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -1,11 +1,22 @@ -## ShardFormer +# ⚡️ ShardFormer -### Intro -Make the model in huggingface.co can be paralleled and can be used with colossalai according to custom policy. +## 📚 Table of Contents + +- [⚡️ ShardFormer](#️-shardformer) + - [📚 Table of Contents](#-table-of-contents) + - [🔗 Introduction](#-introduction) + - [🔨 Usage](#-usage) + - [🔮 Simple example](#-simple-example) + - [💡 Policy](#-policy) + +## 🔗 Introduction + +**Shardformer** is a module that automatically parallelizes the mainstream models in libraries such as HuggingFace and TIMM. This module aims to make parallelization hassle-free for users who are not from the system background. + +## 🔨 Usage + +The sample API usage is given below: -### Quick start -1. Usage -- Use ``` python from colossalai.shardformer.shard.shardmodel import ShardModel from transformers import BertForMaskedLM @@ -21,23 +32,33 @@ shardmodel = ShardModel(model).model from xxx import shardmodel = ShardModel(model, ).model - # do angthing as normal ... ``` -- Policy -If you wanna parallel the model in custom way, just overwrite the policy class for the huggingface model. +## 🔮 Simple example + +``` shell +# inference +colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference +# train +colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train +``` + + +## 💡 Policy + +If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. You should do: 1. Inherit Policy class 2. Overwrite argument_policy method - - In this method you need to list which layers class you wanna modify and the attributes and parameters in those layers. -3. Overwrite inject_policy method [Optional] + - In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers. +3. Overwrite inject_policy method (Optional) - If you need to modify the forward or backward progress. 4. Overwrite or add the param recording functions - - These function use suffix to record the path of weight or bias for the layer. + - These functions use a suffix to record the path of weight or bias for the layer. 5. Overwrite binding More details can be found in shardformer/policies/basepolicy.py @@ -167,11 +188,3 @@ CustomPolicy(Policy): return NotImplementedError ``` - -2. Simple example -``` shell -# inference -colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference -# train -colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train -``` From 0470f1ba29e79f9a056172207cc4674369e3b540 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 24 May 2023 16:01:26 +0800 Subject: [PATCH 04/49] [shardformer] refactored the user api (#3828) * [shardformer] refactored the user api * polish code --- colossalai/shardformer/README.md | 6 +- colossalai/shardformer/shard/__init__.py | 5 ++ .../shard/{shardconfig.py => shard_config.py} | 2 + colossalai/shardformer/shard/sharder.py | 27 ++++++--- colossalai/shardformer/shard/shardmodel.py | 60 ------------------- colossalai/shardformer/shard/slicer.py | 7 +-- colossalai/shardformer/test/test.py | 15 ++--- 7 files changed, 35 insertions(+), 87 deletions(-) rename colossalai/shardformer/shard/{shardconfig.py => shard_config.py} (93%) delete mode 100644 colossalai/shardformer/shard/shardmodel.py diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index f76cbac8d7b8..10fd1809b287 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -18,7 +18,7 @@ The sample API usage is given below: ``` python -from colossalai.shardformer.shard.shardmodel import ShardModel +from colossalai.shardformer import shard_model from transformers import BertForMaskedLM # create huggingface model as normal @@ -26,11 +26,11 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased") # make the huggingface model paralleled to ShardModel # auto policy: -shardmodel = ShardModel(model).model +sharded_model = shard_model(model) # custom policy: from xxx import -shardmodel = ShardModel(model, ).model +sharded_model = shard_model(model, ) # do angthing as normal ... diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index e69de29bb2d1..d5f70163ad57 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -0,0 +1,5 @@ +from .shard_config import ShardConfig +from .sharder import ModelSharder, shard_model +from .slicer import Slicer + +__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer'] diff --git a/colossalai/shardformer/shard/shardconfig.py b/colossalai/shardformer/shard/shard_config.py similarity index 93% rename from colossalai/shardformer/shard/shardconfig.py rename to colossalai/shardformer/shard/shard_config.py index c6a2513a6eff..4cf9162b9548 100644 --- a/colossalai/shardformer/shard/shardconfig.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,5 +1,7 @@ from dataclasses import dataclass +__all__ = ['ShardConfig'] + @dataclass class ShardConfig: diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 2f6bb4265a11..2218661889f8 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,20 +1,15 @@ -import os -from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List import torch import torch.nn as nn -import colossalai.nn as col_nn -from colossalai.logging import get_dist_logger - from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Layer, Policy +from ..policies.basepolicy import Policy from ..utils.utils import getattr_, hasattr_, setattr_ -from .shardconfig import ShardConfig +from .shard_config import ShardConfig from .slicer import Slicer -logger = get_dist_logger() +__all__ = ['ModelSharder', 'shard_model'] class ModelSharder(object): @@ -245,3 +240,17 @@ def bind_layer(self, model: nn.Module) -> None: param = nn.Parameter(param) setattr_(model, k, param) setattr_(model, v, param) + + +def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None): + r""" + The function is used to shard the PyTorch model. + + Args: + model (`torch.nn.Model`): the origin huggingface model + shard_config (`ShardConfig`): the config for distribute information + policy (`Policy`): the custom policy for sharding + """ + sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy) + sharder.shard() + return model diff --git a/colossalai/shardformer/shard/shardmodel.py b/colossalai/shardformer/shard/shardmodel.py deleted file mode 100644 index 7e7d1576afd6..000000000000 --- a/colossalai/shardformer/shard/shardmodel.py +++ /dev/null @@ -1,60 +0,0 @@ -import os -from contextlib import suppress -from dataclasses import dataclass - -import torch -import torch.distributed as dist -import torch.nn as nn -import transformers - -from colossalai.tensor.d_tensor.layout import Layout - -from ..policies.basepolicy import Policy -from .shardconfig import ShardConfig -from .sharder import ModelSharder - - -class ShardModel(object): - r""" - The class for sharding the huggingface model, ''self.model'' is the sharded model - Just creat a new ShardModel object to shard huggingface model - - Args: - model (:class:`torch.nn.Model`): the origin huggingface model - dist_config (:class:`ShardConfig`): the config for distribute information - custom_policy (:class:`Policy`): the custom policy for sharding - """ - - def __init__( - self, - model: nn.Module, - shard_config: ShardConfig = None, # TODO - custom_policy: Policy = None, - ) -> None: - self.model = model - self.shard_config = shard_config - self.policy = custom_policy - # self.layout=, # TODO - - sharder = ModelSharder( - model=self.model, - policy=self.policy, - shard_config=self.shard_config, - ) - sharder.shard() - - def set_environ(self) -> None: - os.environ["TOKENIZERS_PARALLELISM"] = "true" - os.environ["MKL_SERVICE_FORCE_INTEL"] = "GNU" - os.environ["MASTER_ADDR"] = str(self.dist_config.master_addr) - os.environ["MASTER_PORT"] = str(self.dist_config.master_port) - os.environ["WORLD_SIZE"] = str(self.dist_config.num_gpus) - os.environ["RANK"] = str(self.dist_config.rank) - os.environ["LOCAL_RANK"] = str(self.dist_config.rank) - if not dist.is_initialized(): - dist.init_process_group(backend=self.dist_config.backend) - - torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0"))) - - def back_to_org() -> None: - pass diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 096f5db95f49..957ce1f85814 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -1,12 +1,7 @@ -import os -from dataclasses import dataclass -from typing import Dict, Tuple - import torch -import torch.distributed as dist from ..policies.basepolicy import Col_Layer, Layer, Row_Layer -from .shardconfig import ShardConfig +from .shard_config import ShardConfig dim_mapping = {Col_Layer: 1, Row_Layer: 0} diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py index 0cdc6ef38fd2..202208123ced 100644 --- a/colossalai/shardformer/test/test.py +++ b/colossalai/shardformer/test/test.py @@ -1,5 +1,3 @@ -import argparse -import inspect import os import torch @@ -7,12 +5,10 @@ from datasets import load_dataset from torch.utils.data import DataLoader from tqdm.auto import tqdm -from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments +from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling import colossalai -from colossalai.logging import get_dist_logger -from colossalai.shardformer.shard.shardconfig import ShardConfig -from colossalai.shardformer.shard.shardmodel import ShardModel +from colossalai.shardformer.shard import ShardConfig, shard_model from colossalai.utils import get_current_device, print_rank_0 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -93,8 +89,9 @@ def train(model: nn.Module, num_epoch: int = 2): rank=int(str(get_current_device()).split(':')[-1]), world_size=int(os.environ['WORLD_SIZE']), ) - shardmodel = ShardModel(model, shard_config) + sharded_model = shard_model(model, shard_config) + if args.mode == "train": - train(shardmodel.model) + train(sharded_model) elif args.mode == "inference": - inference(shardmodel.model) + inference(sharded_model) From 051e970fac468d7a346539836f954a43584aebc1 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 24 May 2023 18:02:54 +0800 Subject: [PATCH 05/49] [shardformer] update readme with modules implement doc (#3834) * update readme with modules content * remove img --- colossalai/shardformer/README.md | 69 ++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 10fd1809b287..55b6aa75ef84 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -8,6 +8,8 @@ - [🔨 Usage](#-usage) - [🔮 Simple example](#-simple-example) - [💡 Policy](#-policy) + - [😊 Module](#-module) + ## 🔗 Introduction @@ -188,3 +190,70 @@ CustomPolicy(Policy): return NotImplementedError ``` + + +## 😊 Module + + 1. Flowchart + +

+ +

+ + 2. Important Modules + + - CLASS `shard_model`: + + This is the user api to use shardformer, just create a model from transformers and define a custom policy or use shardformer autopolicy to make a shard model. + + - CLASS `Layer`: + + Parameters: + - weight (str): The weight suffix of the layer + - bias (str): The bias suffix of the layer + - replace_layer (:class:`colosalai.nn`): The layer to replace the original layer + - ignore (bool): Whether to ignore this layer if it is not in the model + + This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class. + + CLASS `Col_Layer(Layer)`: + - gather_output (bool): Whether to gather the output of the layer + + This class inherited from `Layer`, representing the layer will be sliced along column. + + CLASS `Row_Layer(Layer)`: + + This class inherited from `Layer`, representing the layer will be sliced along row. + + - CLASS `Policy`: + + In Shardformer, this class holds significant importance as it defines the model partitioning methods, required parameter modifications, and model injection techniques all within a single Policy class. + - `Policy.attn_in()/attn_out()/mlp_in()/mlp_out()/embedding()/unembedding()`...... + + These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions. + - `Policy.argument_policy()` + + In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach. + - `Policy.inject_policy()` + + This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else. + - `Policy.binding_policy()` + + This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters. + + - CLASS `ModelSharder(model, policy)`: + + This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model. + - `ModelShard.inject_model()` + + This function is used to inject the model to modify the forward and backward progress. + - `ModelShard.replace_layer()` + + This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication. + - `ModelShard.bind_layer()` + + This function is used to help different layers share weight or bias. + + - CLASS `Slicer`: + + This class is used to slice tensor according to policy. From 3e840f739c859679424e8f20b4fe6dad27909e80 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 1 Jun 2023 16:21:02 +0800 Subject: [PATCH 06/49] [shardformer] add Dropout layer support different dropout pattern (#3856) * add dropout layer, add dropout test * modify seed manager as context manager * add a copy of col_nn.layer * add dist_crossentropy loss; separate module test * polish the code * fix dist crossentropy loss --- colossalai/nn/layer/parallel_1d/_operation.py | 1 - colossalai/nn/layer/parallel_1d/layers.py | 9 +- colossalai/shardformer/README.md | 19 + colossalai/shardformer/layer/__init__.py | 0 colossalai/shardformer/layer/_operation.py | 97 ++ .../shardformer/layer/dist_crossentropy.py | 105 ++ colossalai/shardformer/layer/dropout.py | 58 + colossalai/shardformer/layer/layers.py | 1043 +++++++++++++++++ colossalai/shardformer/model/modeling_bert.py | 10 +- colossalai/shardformer/policies/basepolicy.py | 2 - colossalai/shardformer/policies/bert.py | 4 +- colossalai/shardformer/shard/slicer.py | 15 +- colossalai/shardformer/test/module_test.py | 50 + colossalai/shardformer/test/test.py | 41 +- 14 files changed, 1413 insertions(+), 41 deletions(-) create mode 100644 colossalai/shardformer/layer/__init__.py create mode 100644 colossalai/shardformer/layer/_operation.py create mode 100644 colossalai/shardformer/layer/dist_crossentropy.py create mode 100644 colossalai/shardformer/layer/dropout.py create mode 100644 colossalai/shardformer/layer/layers.py create mode 100644 colossalai/shardformer/test/module_test.py diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py index c5e33fd497cd..300baf9c12ba 100644 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -73,7 +73,6 @@ def backward(ctx, grad_output): total_input = input grad_input = grad_output.matmul(weight) - grad_output = grad_output.contiguous() # Convert the tensor shapes to 2D for execution compatibility grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 0ee3b4fcb502..406173a18c60 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -469,8 +469,7 @@ def __init__(self, if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size) - self.out_features_per_partition = out_features + self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size) # Parameters. # Initialize weight. @@ -613,8 +612,7 @@ def __init__(self, raise ValueError('cannot skip bias addition if bias is None') # Divide the weight matrix along the last dimension. - # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size) - self.input_size_per_partition = in_features + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) # Parameters. # Initialize weight. @@ -886,8 +884,7 @@ def __init__(self, tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.num_embeddings_per_partition = num_embeddings + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 55b6aa75ef84..3394e9457da3 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -257,3 +257,22 @@ CustomPolicy(Policy): - CLASS `Slicer`: This class is used to slice tensor according to policy. + + + 3. DistCrossEntropy Loss + - Overview + + In order to reduce the communication size, caculate the crossentropy before all gather, refer to [Megatron-LM](https://github.com/NVIDIA/Megatron-LM), reduce the communication size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The origin loss function is: + $$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$ + + alse can be represented as: + + $$ loss = \log(\sum_i\exp(x[i])) - x[class]$$ + + - Step + + - First get the maximum logits across all the devices, make all the logist minus the maximun value to scale the value less than zero to avoid the value of exp being too large + + - Get a mask to mask the logits not in the local device + + - Caculate the loss according to the second formula diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py new file mode 100644 index 000000000000..e817ea3ebbee --- /dev/null +++ b/colossalai/shardformer/layer/_operation.py @@ -0,0 +1,97 @@ +import torch +import torch.distributed as dist + +from colossalai.core import global_context as gpc + +try: + import fused_mix_prec_layer_norm_cuda +except: + fused_mix_prec_layer_norm_cuda = None + + +class FusedLayerNormAffineFunction1D(torch.autograd.Function): + r"""Layernorm + + Args: + input: input matrix. + weight: weight matrix. + bias: bias matrix. + normalized_shape: input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps: a value added to the denominator for numerical stability + """ + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, + bias_, ctx.eps) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias \ + = fused_mix_prec_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) + + return grad_input, grad_weight, grad_bias, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.parallel_mode = parallel_mode + ctx.async_grad_allreduce = async_grad_allreduce + + output = torch.matmul(input_, weight.t()) + if bias is not None: + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) + total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = grad_output.t().matmul(total_input) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py new file mode 100644 index 000000000000..1869594670ce --- /dev/null +++ b/colossalai/shardformer/layer/dist_crossentropy.py @@ -0,0 +1,105 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function + + +class DistCrossEntropy(Function): + r""" + Overwrite the forward and backward function to calculate the cross entropy loss before gather + + Args: + Function (:class:`torch.autograd.Function`): default + """ + + @staticmethod + def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor): + r""" + Calculate the cross entropy loss before gather, the origin loss function is as follows: + loss = -log(exp(x[class])/sum(exp(x[i])) + and can be rewrite as: + loss = log(sum(exp(x[i])) - x[class] + + To avoid the `nan` of log(sim(exp(x[i]))), we minus the max of x[i] + + Args: + vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is + [batch_size, seq_len, vocab_size] + labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is + [batch_size, seq_len] + + Returns: + :class:`torch.Tensor`: The cross entropy loss + """ + # get the max + logits_max = torch.max(vocab_logits, dim=-1)[0] + dist.all_reduce(logits_max, op=dist.ReduceOp.MAX) + + # minus the max to avoid the result of sum of exp is too large and the log is nan + vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) + + # mask the target in the local device + partition_vocab_size = vocab_logits.size()[-1] + rank = dist.get_rank() + world_size = dist.get_world_size() + global_vocab_size = partition_vocab_size * world_size + + # [down, up) => false, other device and -100 => true + delta = (global_vocab_size + world_size - 1) // world_size + down_shreshold = rank * delta + up_shreshold = down_shreshold + delta + mask = (target < down_shreshold) | (target >= up_shreshold) + masked_target = target.clone() - down_shreshold + masked_target[mask] = 0 + + # reshape the logist and target + # reshape the vocab_logits to [bath_size * seq_len, vocab_size] + # reshape the labels to [bath_size * seq_len] + logits_2d = vocab_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + + # extract the x[class] and set the x[other device] to zero + pred_logits_1d = logits_2d[torch.arange(start=0, end=logits_2d.shape[0], device=logits_2d.device), + masked_target_1d] + pred_logits_1d = pred_logits_1d.clone().contiguous() + pred_logits = pred_logits_1d.view_as(target) + pred_logits[mask] = 0.0 + + # allreduce the get all x(i,y) + dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM) + exp_logits = vocab_logits + torch.exp(vocab_logits, out=exp_logits) + sum_exp_logits = torch.sum(exp_logits, dim=-1) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM) + + # calculate the loss + # loss = log(sum(exp(x[i]))) - x[class] + loss = torch.log(sum_exp_logits) - pred_logits + loss = torch.sum(loss).div_(loss.numel()) + + # caculate the softmax + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + # retrieve the saved tensors + exp_logits, mask, masked_target_1d = ctx.saved_tensors + + # use exp logits as the input grad + grad_logits = exp_logits + partion_vocab_size = grad_logits.shape[-1] + grad_logits_2d = grad_logits.view(-1, partion_vocab_size) + + update = 1.0 - mask.view(-1).float() + grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update + + grad_logits.mul_(grad_output.unsqueeze(dim=-1)) + return grad_logits, None, None + + +def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + return DistCrossEntropy.apply(vocab_logits, labels) diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py new file mode 100644 index 000000000000..acc114029ac1 --- /dev/null +++ b/colossalai/shardformer/layer/dropout.py @@ -0,0 +1,58 @@ +import os +import time +from contextlib import contextmanager + +import torch +import torch.nn as nn + + +class SeedManager: + """ + This class is a random state manager to change random state for different random seed. + + """ + + def __init__(self): + original_state = torch.cuda.get_rng_state() + seed = int(f"{int(time.time())}{os.environ['RANK']}") + torch.cuda.manual_seed(int(seed)) + self.dropout_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(original_state) + + def set_mode(self, rng_state): + torch.cuda.set_rng_state(rng_state) + + def get_current_mode(self): + current_state = torch.cuda.get_rng_state() + return current_state + + @contextmanager + def dropout_mode(self): + """ + This is a context manager to change the dropout state and recover the original state. + + Usage: + :: + >>> with _seed_manager.dropout_mode(): + >>> input = super().forward(input) + """ + try: + current_mode = self.get_current_mode() + yield self.set_mode(self.dropout_state) + finally: + self.dropout_state = self.get_current_mode() + self.set_mode(current_mode) + + +_seed_manager = SeedManager() + + +class Dropout1D(nn.Dropout): + + def __init__(self, p=0.5, inplace=False): + super().__init__(p, inplace) + + def forward(self, input): + with _seed_manager.dropout_mode(): + input = super().forward(input) + return input diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py new file mode 100644 index 000000000000..f5123885bbe4 --- /dev/null +++ b/colossalai/shardformer/layer/layers.py @@ -0,0 +1,1043 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from collections import OrderedDict +from typing import Callable, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.parameter import Parameter + +from colossalai.communication import broadcast +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.kernel import LayerNorm +from colossalai.nn import init as init +from colossalai.nn.layer.base_layer import ParallelLayer +from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule +from colossalai.nn.layer.parallel_1d._utils import ( + gather_forward_split_backward, + get_parallel_input, + reduce_grad, + reduce_input, + set_parallel_input, + split_forward_gather_backward, +) +from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition +from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding +from colossalai.registry import LAYERS +from colossalai.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) +from colossalai.utils.cuda import get_current_device + +from ._operation import linear_with_async_comm + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +# @LAYERS.register_module +class Linear1D(ColossalaiModule): + r"""Linear layer for 1D parallelism. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + gather_output (bool, optional): Whether to call all-gather on output, defaults to False. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + parallel_input = get_parallel_input() + if not parallel_input and not gather_output: + layer = Linear1D_Col(in_features, + out_features, + bias=bias, + dtype=dtype, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + else: + layer = Linear1D_Row(in_features, + out_features, + bias=bias, + dtype=dtype, + parallel_input=parallel_input, + skip_bias_add=skip_bias_add, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + super().__init__(layer) + + +# @LAYERS.register_module +class LayerNorm1D(ColossalaiModule): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps).to(dtype) + super().__init__(norm) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) + + +# @LAYERS.register_module +class Classifier1D(ParallelLayer): + r"""RowLinear with given weight. Classifier of 1D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = False + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + if self.has_weight: + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ + 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + output_parallel = F.linear(input_, self.weight) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if self.bias is not None: + output = output + self.bias + return output + + +# @LAYERS.register_module +class VocabParallelClassifier1D(ParallelLayer): + r"""ColLinear with given weight. Classifier of 1D parallelism. + + Args: + in_features (int): size of each input sample. + num_classes (int): number of classes. + weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + num_classes: int, + weight: Parameter = None, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + self.in_features = in_features + self.num_classes = num_classes + self.gather_output = gather_output + self.parallel_input = get_parallel_input() + + # Divide the weight matrix along the last dimension. + self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size) + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + if weight is not None: + self.weight = weight + self.has_weight = False + else: + self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs)) + self.has_weight = True + if bias: + self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.num_classes + if self.has_weight: + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + if self.has_weight: + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + if self.has_weight: + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict() + if self.has_weight: + local_state[weight_key] = self.weight + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + return output + + +# @LAYERS.register_module +class Linear1D_Col(ParallelLayer): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size) + self.out_features_per_partition = out_features + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + is_parallel_output = not self.gather_output + set_parallel_input(is_parallel_output) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + if self.bias is not None: + set_tensor_parallel_attribute_by_partition(self.bias, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: 0, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: True + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + # output_parallel = F.linear(input_parallel, self.weight, bias) + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +# @LAYERS.register_module +class Linear1D_Row(ParallelLayer): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size) + self.input_size_per_partition = in_features + + # Parameters. + # Initialize weight. + factory_kwargs = {'device': get_current_device(), 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + with seed(ParallelMode.TENSOR): + self.reset_parameters(weight_initializer, bias_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) + + def _set_tensor_parallel_attributes(self): + num_partition = gpc.get_world_size(ParallelMode.TENSOR) + set_tensor_parallel_attribute_by_partition(self.weight, num_partition) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + if self.bias is not None: + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + local_state = OrderedDict({weight_key: self.weight}) + if self.bias is not None: + local_state[bias_key] = self.bias + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={ + weight_key: -1, + bias_key: 0 + }, + partition_states={ + weight_key: True, + bias_key: False + }, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) + input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=gpc.get_group(ParallelMode.PARALLEL_1D), + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias + + +# @LAYERS.register_module +class Embedding1D(ParallelLayer): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + self.weight = Parameter( + torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: -1}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + + return output + + +# @LAYERS.register_module +class VocabParallelEmbedding1D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + + tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) + tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.num_embeddings_per_partition = num_embeddings + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)) + + self.reset_parameters(weight_initializer) + self._set_tensor_parallel_attributes() + set_parallel_input(False) + env.vocab_parallel = True + + def _set_tensor_parallel_attributes(self): + set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _load_from_global_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + + local_state = partition_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}) + super()._load_from_global_state_dict(local_state, prefix, *args) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + return output + + +# @LAYERS.register_module +class Dropout1D(ParallelLayer): + """Dropout layer of 1D parallelism. + + Args: + p (float, optional): probability of an element to be zeroed, defaults 0.5. + inplace (bool, optional): whether to do dropout in-place, default to be False. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False): + super().__init__() + self.parallel_input = get_parallel_input() + self.p = p + self.inplace = inplace + + def forward(self, input_: Tensor) -> Tensor: + if self.parallel_input: + with seed(ParallelMode.TENSOR): + output = F.dropout(input_, self.p, self.training, self.inplace) + else: + output = F.dropout(input_, self.p, self.training, self.inplace) + return output + + +# @LAYERS.register_module +class PatchEmbedding1D(ColossalaiModule): + """ + 2D Image to Patch Embedding + + :param img_size: image size + :type img_size: int + :param patch_size: patch size + :type patch_size: int + :param in_chans: number of channels of input image + :type in_chans: int + :param embed_size: size of embedding + :type embed_size: int + :param dtype: The dtype of parameters, defaults to None + :type dtype: torch.dtype, optional + :param flatten: whether to flatten output tensor, defaults to True + :type flatten: bool, optional + :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer + :type weight_initializer: typing.Callable, optional + :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer + :type bias_initializer: typing.Callable, optional + :param position_embed_initializer: The initializer of position embedding, defaults to zero + :type position_embed_initializer: typing.Callable, optional + """ + + def __init__(self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: torch.dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_()): + embed = VanillaPatchEmbedding(img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer) + super().__init__(embed) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed'] + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + for key in param_keys: + param = state_dict.pop(key, None) + if param is not None: + local_state[key] = param + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py index 6741ae866991..bd07ab80c00d 100644 --- a/colossalai/shardformer/model/modeling_bert.py +++ b/colossalai/shardformer/model/modeling_bert.py @@ -6,6 +6,8 @@ from transformers import BertForMaskedLM from transformers.models.bert.modeling_bert import MaskedLMOutput +from ..layer.dist_crossentropy import applyDistCrossEntropy + class BertForMaskedLM_(BertForMaskedLM): @@ -47,11 +49,11 @@ def forward( masked_lm_loss = None - # if input_ids is not None: - # masked_lm_loss = applyDistCrossEntropy(prediction_scores, input_ids, self.config.vocab_size) if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token - masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + masked_lm_loss = applyDistCrossEntropy(prediction_scores, labels) + # if labels is not None: + # loss_fct = CrossEntropyLoss() # -100 index = padding token + # masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: output = (prediction_scores,) + outputs[2:] diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index a5cc0bc68df6..2eb7eb29e1a4 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -7,8 +7,6 @@ import torch.nn as nn from transformers import AutoConfig -import colossalai.nn as col_nn - @dataclass class Argument: diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 5d91d8ddc766..ab77b29f71f4 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -4,7 +4,7 @@ import torch.nn as nn from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead -import colossalai.nn as col_nn +import colossalai.shardformer.layer.layers as col_nn from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer @@ -142,7 +142,7 @@ def unembedding() -> List: weight="decoder.weight", bias="decoder.bias", replace_layer=col_nn.Linear1D_Col, - gather_output=True, + # gather_output=True, ) ] diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 957ce1f85814..26053b9f7408 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -94,10 +94,7 @@ def slice_1d( Returns: :class:`torch.Tensor`: The sliced tensor """ - delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size - down_idx = self.shardconfig.rank * delta - up_idx = down_idx + delta - return tensor[down_idx:up_idx].contiguous() + return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() def slice_col( self, @@ -113,10 +110,7 @@ def slice_col( :class:`torch.Tensor`: The sliced tensor """ - delta = (tensor.shape[0] + self.shardconfig.world_size - 1) // self.shardconfig.world_size - down_idx = self.shardconfig.rank * delta - up_idx = down_idx + delta - return tensor[down_idx:up_idx, :].contiguous() + return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() def slice_row( self, @@ -131,7 +125,4 @@ def slice_row( Returns: :class:`torch.Tensor`: The sliced tensor """ - delta = (tensor.shape[1] + self.shardconfig.world_size - 1) // self.shardconfig.world_size - down_idx = self.shardconfig.rank * delta - up_idx = down_idx + delta - return tensor[:, down_idx:up_idx].contiguous() + return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() diff --git a/colossalai/shardformer/test/module_test.py b/colossalai/shardformer/test/module_test.py new file mode 100644 index 000000000000..83dc7ec6cf4a --- /dev/null +++ b/colossalai/shardformer/test/module_test.py @@ -0,0 +1,50 @@ +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import colossalai +from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy +from colossalai.shardformer.layer.dropout import Dropout1D + + +def get_args(): + parser = colossalai.get_default_parser() + parser.add_argument("--module", type=str, default='distloss') + return parser.parse_args() + + +def test_dist_crossentropy(): + pred = torch.randn(2, 4, 8, requires_grad=True) + labels = torch.randint(8, (1, 4)).repeat(2, 1) + + pred_ = pred.view(-1, 8) + labels_ = labels.view(-1) + loss = F.cross_entropy(pred_, labels_) + loss.backward() + print(f"normal loss:{loss}") + + pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])] + loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda')) + loss.backward() + print(f"dist loss:{loss}") + + +def test_dropout(): + input = torch.randn(5, 4).to("cuda") + m = Dropout1D(p=0.2).to("cuda") + for i in range(2): + print(f"Output: {m(input)}") + print(torch.randn(1)) + + +if __name__ == '__main__': + args = get_args() + colossalai.launch_from_torch(config={}) + if args.module == 'distloss': + test_dist_crossentropy() + elif args.module == 'dropout': + test_dropout() + else: + print("not implemented yet") diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py index 202208123ced..b896fd4a4020 100644 --- a/colossalai/shardformer/test/test.py +++ b/colossalai/shardformer/test/test.py @@ -1,11 +1,12 @@ import os +import random import torch import torch.nn as nn from datasets import load_dataset from torch.utils.data import DataLoader from tqdm.auto import tqdm -from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling +from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler import colossalai from colossalai.shardformer.shard import ShardConfig, shard_model @@ -18,6 +19,7 @@ def get_args(): parser = colossalai.get_default_parser() parser.add_argument("--mode", type=str, default='inference') + parser.add_argument("--save_model", action='store_true') return parser.parse_args() @@ -30,36 +32,40 @@ def load_data(): # tokenized_datasets=tokenized_datasets.rename_column("label","labels") tokenized_datasets.set_format("torch") - train_dataset = tokenized_datasets["train"].select(range(500)) - test_dataset = tokenized_datasets["test"].select(range(100)) + train_dataset = tokenized_datasets["train"] + test_dataset = tokenized_datasets["test"] datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt") - train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=datacollector) - eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True, collate_fn=datacollector) + train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector) + eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector) return train_dataloader, eval_dataloader -def inference(model: nn.Module): - print(model) +def inference(model: nn.Module, args): tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") token = "Hello, my dog is cute" inputs = tokenizer(token, return_tensors="pt") inputs.to("cuda") + model.eval() model.to("cuda") outputs = model(**inputs) print(outputs) -def train(model: nn.Module, num_epoch: int = 2): +def train(model: nn.Module, args, num_epoch: int = 3): train_dataloader, eval_dataloader = load_data() optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) - progress_bar = tqdm(range((num_epoch) * len(train_dataloader))) - criterion = nn.CrossEntropyLoss() + num_training = num_epoch * len(train_dataloader) + progress_bar = tqdm(range(num_training)) + lr_scheduler = get_scheduler(name="linear", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=num_training) + best_test_loss = float("inf") model.to("cuda") model.train() for epoch in range(num_epoch): progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}") - for batch in train_dataloader: optimizer.zero_grad() batch = {k: v.to('cuda') for k, v in batch.items()} @@ -67,6 +73,7 @@ def train(model: nn.Module, num_epoch: int = 2): loss = outputs.loss loss.backward() optimizer.step() + lr_scheduler.step() progress_bar.update(1) train_loss = loss @@ -75,16 +82,20 @@ def train(model: nn.Module, num_epoch: int = 2): batch = {k: v.to('cuda') for k, v in batch.items()} outputs = model(**batch) # loss = outputs.loss + assert not torch.isnan(outputs.loss), f"{batch}" loss += outputs.loss.item() # loss = criterion(outputs.logits, batch["input_ids"]) test_loss = loss / len(eval_dataloader) print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}") + if args.save_model and test_loss < best_test_loss: + best_test_loss = test_loss + torch.save(model.state_dict(), "./checkpoints/best_model.pth") if __name__ == "__main__": args = get_args() - colossalai.launch_from_torch(config=args.config) model = BertForMaskedLM.from_pretrained("bert-base-uncased") + colossalai.launch_from_torch(config=args.config) shard_config = ShardConfig( rank=int(str(get_current_device()).split(':')[-1]), world_size=int(os.environ['WORLD_SIZE']), @@ -92,6 +103,8 @@ def train(model: nn.Module, num_epoch: int = 2): sharded_model = shard_model(model, shard_config) if args.mode == "train": - train(sharded_model) + train(sharded_model, args) elif args.mode == "inference": - inference(sharded_model) + inference(sharded_model, args) + else: + raise NotImplementedError From bf9c2fde68b5a681ed8abb2331c3c4866acde6bb Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 6 Jun 2023 15:31:52 +0800 Subject: [PATCH 07/49] update README (#3909) --- colossalai/shardformer/README.md | 46 ++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 3394e9457da3..93a4f1e578e4 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -55,30 +55,37 @@ If you wanna parallel the model in a custom way, just overwrite the policy class You should do: 1. Inherit Policy class -2. Overwrite argument_policy method - - In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers. -3. Overwrite inject_policy method (Optional) - - If you need to modify the forward or backward progress. -4. Overwrite or add the param recording functions +2. Overwrite `argument_policy` method + - In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers. Shardformer will replace all the layer belonging to the class you specified. + - `attr_dict` is dict contains all the attributes need to be modified in this layer. + - `param_funcs` is a list contains some functions which will return the path of the weight and bias from the layer. +3. Overwrite `inject_policy` method (Optional) + - Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method. +4. Overwrite or add the param functions - These functions use a suffix to record the path of weight or bias for the layer. -5. Overwrite binding + - The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively. +5. Overwrite `binding_policy` (Optional) + - Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers. + - This function will return a dict, the key and value are the suffix of weight need to be binded. More details can be found in shardformer/policies/basepolicy.py ``` python from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument CustomPolicy(Policy): - @staticmethod - def argument_policy(model_config, shard_config: int) -> Dict[nn.Module,Argument]: - """ - Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions +@staticmethod + def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: + r""" + Return the dict for the modify policy, the key is the original layer class and the value is the + argument for the modify layer Args: - model_config: The config of transformer model - shard_setting: The config of distributed model + model_config (:class:`tansformer.Config`): The config of transformer model + shard_config (:class:`ShardConfig`): The config for sharding model Return: Dict for the modify policy, + :: { origin layer class1 (nn.Module): Argument( attr_dict = { @@ -112,18 +119,29 @@ CustomPolicy(Policy): @staticmethod def inject_policy() -> Tuple[nn.Module, nn.Module]: - """ + r""" Return the dict for the inject model Return: The injected model, key is the original model and value is the new shardmodel + :: + (OrignModel, CustomModel) + in `CustomModel`, we can overwrite the forward and backward process """ return () @staticmethod def binding_policy() -> Dict: - """ + r""" Return the dict for the binding model + + Return: + This method should return the binding relationship for some layers share the weight or bias, + the key and value is the suffix of the weight or bias of the model + :: + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", + } """ return NotImplementedError From 551fec318a2a23b79f28275515cb8ec49a9d68ae Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 7 Jun 2023 16:09:40 +0800 Subject: [PATCH 08/49] [shardformer] add gpt2 policy and modify shard and slicer to support (#3883) * add gpt2 policy and modify shard and slicer to support * remove unused code * polish code --- colossalai/shardformer/policies/autopolicy.py | 14 ++- colossalai/shardformer/policies/basepolicy.py | 17 ++- colossalai/shardformer/policies/bert.py | 1 - colossalai/shardformer/policies/gpt2.py | 118 ++++++++++++++++++ colossalai/shardformer/shard/sharder.py | 46 ++++--- colossalai/shardformer/shard/slicer.py | 53 ++++++-- colossalai/shardformer/test/test.py | 28 +++-- 7 files changed, 233 insertions(+), 44 deletions(-) create mode 100644 colossalai/shardformer/policies/gpt2.py diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index e096c2b13a59..54cc63ba124f 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -10,16 +10,26 @@ def build_policies(): """ auto_policy_dict = {} - from transformers.models.bert.modeling_bert import BertForMaskedLM + from transformers import BertForMaskedLM from .bert import BertForMaskedLMPolicy auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy - from transformers.models.bert.modeling_bert import BertForSequenceClassification + from transformers import BertForSequenceClassification from .bert import BertForSequenceClassificationPolicy auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy + from transformers import GPT2Model + + from .gpt2 import GPT2Policy + auto_policy_dict[GPT2Model] = GPT2Policy + + from transformers import GPT2LMHeadModel + + from .gpt2 import GPT2LMHeadModelPolicy + auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy + return auto_policy_dict diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 2eb7eb29e1a4..644d115a270e 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -1,11 +1,9 @@ # part of code modified from https://github.com/tunib-ai/parallelformers -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Tuple, Type -import torch import torch.nn as nn -from transformers import AutoConfig @dataclass @@ -31,11 +29,18 @@ class Layer: bias (str): The bias suffix of the layer replace_layer (:class:`colosalai.nn`): The layer to replace the original layer ignore (bool): Whether to ignore this layer if it is not in the model + reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], + but in GPT2 `Conv1D` layer is [in, out] which is reversed. + n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices, + but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and + each device should have a part of Q, K and V weight. """ weight: str = None bias: str = None replace_layer: Any = None ignore: bool = False + reversed: bool = False + n_cast: int = None @dataclass @@ -131,7 +136,7 @@ def inject_policy() -> Tuple[nn.Module, nn.Module]: (OrignModel, CustomModel) in `CustomModel`, we can overwrite the forward and backward process """ - return () + return None @staticmethod def binding_policy() -> Dict: @@ -146,7 +151,7 @@ def binding_policy() -> Dict: "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", } """ - return NotImplementedError + return None @staticmethod def attn_in() -> List: @@ -209,4 +214,4 @@ def unembedding() -> List: Return: List[Layer]: List of layer object """ - return NotImplementedError + return None diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index ab77b29f71f4..89b32f065c27 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any, Callable, Dict, List, Tuple, Type import torch.nn as nn diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py new file mode 100644 index 000000000000..44dc9c72f986 --- /dev/null +++ b/colossalai/shardformer/policies/gpt2.py @@ -0,0 +1,118 @@ +from typing import Any, Callable, Dict, List, Tuple, Type + +import torch.nn as nn +from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model + +import colossalai.shardformer.layer.layers as col_nn + +from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer + + +class GPT2Policy(Policy): + + @staticmethod + def argument_policy(config, world_size): + return { + GPT2Model: + Argument(attr_dict={}, param_funcs=[ + GPT2Policy.embedding, + ]), + GPT2Block: + Argument( + attr_dict={ + # 1. reduce hidden size + "attn.embed_dim": config.hidden_size // world_size, + "attn.split_size": config.hidden_size // world_size, + "crossattention.embed_dim": config.hidden_size // world_size, + "crossattention.split_size": config.hidden_size // world_size, + # 2. reduce number of heads + "attn.num_heads": config.num_attention_heads // world_size, + "crossattention.num_heads": config.num_attention_heads // world_size, + }, + param_funcs=[ + GPT2Policy.attn_in, + GPT2Policy.attn_out, + GPT2Policy.mlp_in, + GPT2Policy.mlp_out, + ]), + } + + @staticmethod + def attn_in() -> List: + return [ + Col_Layer(weight="attn.c_attn.weight", + bias="attn.c_attn.bias", + n_cast=3, + reversed=True, + replace_layer=col_nn.Linear1D_Col), + Col_Layer(weight="crossattention.c_attn.weight", + bias="crossattention.c_attn.bias", + n_cast=2, + reversed=True, + ignore=True, + replace_layer=col_nn.Linear1D_Col), + Col_Layer(weight="crossattention.q_attn.weight", + bias="crossattention.q_attn.bias", + reversed=True, + ignore=True, + replace_layer=col_nn.Linear1D_Col) + ] + + @staticmethod + def attn_out() -> List: + return [ + Row_Layer(weight="attn.c_proj.weight", + bias="attn.c_proj.bias", + reversed=True, + replace_layer=col_nn.Linear1D_Row), + Row_Layer(weight="crossattention.c_proj.weight", + bias="crossattention.c_proj.bias", + reversed=True, + ignore=True, + replace_layer=col_nn.Linear1D_Row) + ] + + @staticmethod + def mlp_in() -> List: + return [ + Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col), + ] + + @staticmethod + def mlp_out() -> List: + return [ + Row_Layer(weight="mlp.c_proj.weight", + bias="mlp.c_proj.bias", + reversed=True, + replace_layer=col_nn.Linear1D_Row) + ] + + @staticmethod + def embedding() -> List: + return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)] + + +from transformers import GPT2LMHeadModel + + +class GPT2LMHeadModelPolicy(GPT2Policy): + + @staticmethod + def argument_policy(config, world_size): + base_argument = GPT2Policy.argument_policy(config, world_size) + argument = { + GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[ + GPT2LMHeadModelPolicy.unembedding, + ]), + } + argument.update(base_argument) + return argument + + @staticmethod + def unembedding() -> List: + return [ + Col_Layer(weight="lm_head.weight", + bias="lm_head.bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True) + ] diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 2218661889f8..1ada75e06b67 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +from transformers.pytorch_utils import Conv1D from ..policies.autopolicy import get_autopolicy from ..policies.basepolicy import Policy @@ -35,10 +36,22 @@ def __init__( self.model_config = self.model.config def shard(self) -> None: + self.reshape_embedding() self.inject_model(self.model) self.replace_layer(self.model) self.bind_layer(self.model) + def reshape_embedding(self,) -> None: + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model_config.vocab_size + world_size = self.shard_config.world_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + self.model_config = self.model.config + def inject_model( self, model: nn.Module, @@ -53,6 +66,8 @@ def inject_model( """ inject_policy = self.policy.inject_policy() + if inject_policy is None: + return org_model_cls = inject_policy[0] shard_model_cls = inject_policy[1] @@ -82,9 +97,9 @@ def replace_layer( origin_layer_cls = argument_policy[0] attr_dict = argument_policy[1].attr_dict param_funcs = argument_policy[1].param_funcs - self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) + self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) - def reverse_replace_layer( + def traverse_replace_layer( self, layer: nn.Module, origin_cls: nn.Module, @@ -100,17 +115,12 @@ def reverse_replace_layer( attr_dict (Dict): The attribute dict to modify policy_cls (:class:`Policy`): The policy class """ + if layer.__class__ == origin_cls: + for k, v in attr_dict.items(): + setattr_(layer, k, v, ignore=True) + self.shard_one_layer(layer, param_funcs) for name, child in layer.named_children(): - if child.__class__ == origin_cls: - # replac_layer = child - for k, v in attr_dict.items(): - setattr_(child, k, v, ignore=True) - # print(f"Sharding {name} layer", replac_layer.attention.self.__dict__) - # setattr_(layer, name, self.shard_one_layer(child, policy_cls)) - self.shard_one_layer(child, param_funcs) - continue - - self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs) + self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs) return layer def shard_one_layer( @@ -126,7 +136,6 @@ def shard_one_layer( param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class """ - # print(org_layer) for func in param_funcs: policy_layers = func() for policy_layer in policy_layers: @@ -136,9 +145,10 @@ def shard_one_layer( bias_attr = policy_layer.bias replace_layer_cls = policy_layer.replace_layer ignore = policy_layer.ignore + n_cast = policy_layer.n_cast + reversed = policy_layer.reversed if policy_layer.__class__.__name__ == "Col_Layer": gather_output = policy_layer.gather_output - # print(gather_output) if weight_attr is not None: if hasattr_(org_layer, weight_attr): @@ -161,13 +171,11 @@ def shard_one_layer( layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr) # slice weight and bias - weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__) - # print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None) + weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed) # create new object to replace the origin layer if replace_layer_cls is not None: - # print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}") - if isinstance(getattr_(org_layer, layer_attr), nn.Linear): + if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)): if replace_layer_cls.__name__ == "Linear1D_Row": replace_layer = replace_layer_cls(weight.shape[1], weight.shape[0], @@ -235,6 +243,8 @@ def bind_layer(self, model: nn.Module) -> None: model (:class:`torch.nn.Module`): The shard model """ binding_map = self.policy.binding_policy() + if binding_map is None: + return for k, v in binding_map.items(): param = getattr_(model, k) param = nn.Parameter(param) diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 26053b9f7408..6d35bd193fed 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -19,6 +19,8 @@ def slice_weight_bias( weight: torch.Tensor, bias: torch.Tensor, policy_layer_cls: Layer, + n_cast: int = None, + reversed: bool = False, ): r""" Slice the weight and bias according to policy layer cls @@ -33,13 +35,18 @@ def slice_weight_bias( """ if policy_layer_cls == Layer: return weight, bias - elif policy_layer_cls == Col_Layer: - weight = self.slice_tensor(weight, 1, False) + + dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls]) + # print(weight.shape, dim) + if policy_layer_cls == Col_Layer: + weight = self.slice_tensor(weight, dim, False, n_cast) bias = self.slice_tensor(bias, 0, True) elif policy_layer_cls == Row_Layer: - weight = self.slice_tensor(weight, 0, False) + weight = self.slice_tensor(weight, dim, False, n_cast) else: raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported") + if reversed: + weight = weight.transpose(0, 1).contiguous() return weight, bias def slice_tensor( @@ -47,6 +54,7 @@ def slice_tensor( tensor_in: torch.Tensor, dim: int, is_bias: bool, + n_cast: int = None, ) -> torch.Tensor: r""" Slice tensor according to the config @@ -59,14 +67,15 @@ def slice_tensor( if tensor_in is None: return None if not is_bias: - return self.slice_2d(tensor_in, dim) + return self.slice_2d(tensor_in, dim, n_cast) else: - return self.slice_1d(tensor_in) + return self.slice_1d(tensor_in, n_cast) def slice_2d( self, tensor: torch.Tensor, dim: int, + n_cast: int = None, ) -> torch.Tensor: r""" Slice the 2D tensor @@ -77,13 +86,14 @@ def slice_2d( """ assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor" if dim == 0: - return self.slice_row(tensor) + return self.slice_row(tensor, n_cast) elif dim == 1: - return self.slice_col(tensor) + return self.slice_col(tensor, n_cast) def slice_1d( self, tensor: torch.Tensor, + n_cast: int = None, ) -> torch.Tensor: r""" Slice the 1D tensor @@ -94,11 +104,19 @@ def slice_1d( Returns: :class:`torch.Tensor`: The sliced tensor """ - return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() + if n_cast is None: + return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() + else: + tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) + chunk_list = [ + tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) + ] + return torch.cat(chunk_list, dim=0).contiguous() def slice_col( self, tensor: torch.Tensor, + n_cast: int = None, ) -> torch.Tensor: r""" Slice the tensor in column @@ -110,11 +128,19 @@ def slice_col( :class:`torch.Tensor`: The sliced tensor """ - return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() + if n_cast is None: + return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() + else: + tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) + chunk_list = [ + tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) + ] + return torch.cat(chunk_list, dim=0).contiguous() def slice_row( self, tensor: torch.Tensor, + n_cast: int = None, ) -> torch.Tensor: r""" Slice the tensor in column @@ -125,4 +151,11 @@ def slice_row( Returns: :class:`torch.Tensor`: The sliced tensor """ - return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() + if n_cast is None: + return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() + else: + tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1) + chunk_list = [ + tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) + ] + return torch.cat(chunk_list, dim=1).contiguous() diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py index b896fd4a4020..e2d5a94c782a 100644 --- a/colossalai/shardformer/test/test.py +++ b/colossalai/shardformer/test/test.py @@ -6,24 +6,28 @@ from datasets import load_dataset from torch.utils.data import DataLoader from tqdm.auto import tqdm -from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler +from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler import colossalai from colossalai.shardformer.shard import ShardConfig, shard_model from colossalai.utils import get_current_device, print_rank_0 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") def get_args(): parser = colossalai.get_default_parser() parser.add_argument("--mode", type=str, default='inference') parser.add_argument("--save_model", action='store_true') + parser.add_argument("--model", type=str, default='bert-base-uncased') return parser.parse_args() -def load_data(): +def load_data(args): + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + # tokenizer.pad_token_id = 0 datasets = load_dataset('wikitext', 'wikitext-2-raw-v1') # datasets=load_dataset("yelp_review_full") tokenized_datasets = datasets.map( @@ -42,18 +46,23 @@ def load_data(): def inference(model: nn.Module, args): - tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + print(model) + # print(model.wte.weight.shape) + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + tokenizer.pad_token_id = 0 token = "Hello, my dog is cute" inputs = tokenizer(token, return_tensors="pt") inputs.to("cuda") model.eval() model.to("cuda") outputs = model(**inputs) - print(outputs) + print(outputs[0]) def train(model: nn.Module, args, num_epoch: int = 3): - train_dataloader, eval_dataloader = load_data() + train_dataloader, eval_dataloader = load_data(args) optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) num_training = num_epoch * len(train_dataloader) progress_bar = tqdm(range(num_training)) @@ -94,8 +103,13 @@ def train(model: nn.Module, args, num_epoch: int = 3): if __name__ == "__main__": args = get_args() - model = BertForMaskedLM.from_pretrained("bert-base-uncased") colossalai.launch_from_torch(config=args.config) + if args.model == 'bert-base-uncased': + model = BertForMaskedLM.from_pretrained("bert-base-uncased") + elif args.model == 'gpt2': + model = GPT2LMHeadModel.from_pretrained("gpt2") + else: + raise AttributeError("model not supported") shard_config = ShardConfig( rank=int(str(get_current_device()).split(':')[-1]), world_size=int(os.environ['WORLD_SIZE']), From e5bc7e397c0353533294ea8500bc0e64e3f9304b Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 9 Jun 2023 14:36:54 +0800 Subject: [PATCH 09/49] [shardformer] Align bert value (#3907) * add bert align test, fix dist loss bug * forward and backward align * add ignore index * add shardformer CI * add gather_output optional for user in shardconfig * update readme with optional gather_ouput * add dist crossentropy loss test, remove unused files * remove unused file * remove unused file * rename the file * polish code --- colossalai/shardformer/README.md | 13 +- colossalai/shardformer/__init__.py | 1 + .../shardformer/layer/dist_crossentropy.py | 10 +- colossalai/shardformer/policies/bert.py | 5 +- colossalai/shardformer/shard/shard_config.py | 18 ++- colossalai/shardformer/shard/sharder.py | 4 +- colossalai/shardformer/test/config.py | 1 - colossalai/shardformer/test/module_test.py | 50 ------- colossalai/shardformer/test/test.py | 124 ------------------ .../test_model/test_shard_bert.py | 103 +++++++++++++++ .../test_module/test_distcrossentropy.py | 42 ++++++ 11 files changed, 174 insertions(+), 197 deletions(-) delete mode 100644 colossalai/shardformer/test/config.py delete mode 100644 colossalai/shardformer/test/module_test.py delete mode 100644 colossalai/shardformer/test/test.py create mode 100644 tests/test_shardformer/test_model/test_shard_bert.py create mode 100644 tests/test_shardformer/test_module/test_distcrossentropy.py diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 93a4f1e578e4..222626db3e9d 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -20,7 +20,7 @@ The sample API usage is given below: ``` python -from colossalai.shardformer import shard_model +from colossalai.shardformer import ShardConfig, shard_model from transformers import BertForMaskedLM # create huggingface model as normal @@ -28,7 +28,12 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased") # make the huggingface model paralleled to ShardModel # auto policy: -sharded_model = shard_model(model) +shardconfig = ShardConfig( + rank=rank, + world_size=world_size, + gather_output=True, +) +sharded_model = shard_model(model, config=shardconfig) # custom policy: from xxx import @@ -72,7 +77,7 @@ More details can be found in shardformer/policies/basepolicy.py ``` python from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument -CustomPolicy(Policy): +class CustomPolicy(Policy): @staticmethod def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: r""" @@ -235,7 +240,7 @@ CustomPolicy(Policy): This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class. CLASS `Col_Layer(Layer)`: - - gather_output (bool): Whether to gather the output of the layer + - gather_output (bool): Whether the output of this layer can be gathered, like the last layer can be gathered, but most of the time, the intermediate layers of the model do not need to be gathered. This class inherited from `Layer`, representing the layer will be sliced along column. diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index e69de29bb2d1..50c92738077a 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -0,0 +1 @@ +from .shard import ShardConfig, shard_model diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py index 1869594670ce..05c04bb545c1 100644 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ b/colossalai/shardformer/layer/dist_crossentropy.py @@ -14,7 +14,7 @@ class DistCrossEntropy(Function): """ @staticmethod - def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor): + def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: loss = -log(exp(x[class])/sum(exp(x[i])) @@ -75,8 +75,8 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor): # calculate the loss # loss = log(sum(exp(x[i]))) - x[class] - loss = torch.log(sum_exp_logits) - pred_logits - loss = torch.sum(loss).div_(loss.numel()) + loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits) + loss = torch.sum(loss).div_(torch.sum(loss != 0.0)) # caculate the softmax exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) @@ -101,5 +101,5 @@ def backward(ctx, grad_output): return grad_logits, None, None -def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels) +def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor: + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 89b32f065c27..5d489f41986c 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -141,7 +141,7 @@ def unembedding() -> List: weight="decoder.weight", bias="decoder.bias", replace_layer=col_nn.Linear1D_Col, - # gather_output=True, + gather_output=True, ) ] @@ -155,7 +155,8 @@ class BertForMaskedLMPolicy(BertPolicy): @staticmethod def inject_policy() -> Tuple[nn.Module, nn.Module]: - return (BertForMaskedLM, BertForMaskedLM_) + # return (BertForMaskedLM, BertForMaskedLM_) + return None class BertForSequenceClassificationPolicy(BertPolicy): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 4cf9162b9548..e8d6f3408c76 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -5,16 +5,14 @@ @dataclass class ShardConfig: - """ - The config for sharding the huggingface model for test + r""" + The config for sharding the huggingface model + + Args: + rank (int): The rank of local process + world_size (int): The world size of the distributed process + gather_output (bool): Whether to gather the output of the model of the last layer """ rank: int - fp16: bool = True - num_gpus: int = 2 world_size: int = 2 - backend = "nccl" - verbose: str = 'simple' - seed: int = None - require_grad: bool = False - master_addr: str = "127.0.0.1" - master_port: int = 29500 + gather_output: bool = True diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 1ada75e06b67..159bebccd02d 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -65,6 +65,8 @@ def inject_model( BertForMaskedLM.forward -> BertForMaskedLM_.forward """ inject_policy = self.policy.inject_policy() + if inject_policy is None: + return if inject_policy is None: return @@ -148,7 +150,7 @@ def shard_one_layer( n_cast = policy_layer.n_cast reversed = policy_layer.reversed if policy_layer.__class__.__name__ == "Col_Layer": - gather_output = policy_layer.gather_output + gather_output = policy_layer.gather_output and self.shard_config.gather_output if weight_attr is not None: if hasattr_(org_layer, weight_attr): diff --git a/colossalai/shardformer/test/config.py b/colossalai/shardformer/test/config.py deleted file mode 100644 index 2b80d8b3ca12..000000000000 --- a/colossalai/shardformer/test/config.py +++ /dev/null @@ -1 +0,0 @@ -parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')) diff --git a/colossalai/shardformer/test/module_test.py b/colossalai/shardformer/test/module_test.py deleted file mode 100644 index 83dc7ec6cf4a..000000000000 --- a/colossalai/shardformer/test/module_test.py +++ /dev/null @@ -1,50 +0,0 @@ -import os - -import torch -import torch.nn as nn -import torch.nn.functional as F - -import colossalai -from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy -from colossalai.shardformer.layer.dropout import Dropout1D - - -def get_args(): - parser = colossalai.get_default_parser() - parser.add_argument("--module", type=str, default='distloss') - return parser.parse_args() - - -def test_dist_crossentropy(): - pred = torch.randn(2, 4, 8, requires_grad=True) - labels = torch.randint(8, (1, 4)).repeat(2, 1) - - pred_ = pred.view(-1, 8) - labels_ = labels.view(-1) - loss = F.cross_entropy(pred_, labels_) - loss.backward() - print(f"normal loss:{loss}") - - pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])] - loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda')) - loss.backward() - print(f"dist loss:{loss}") - - -def test_dropout(): - input = torch.randn(5, 4).to("cuda") - m = Dropout1D(p=0.2).to("cuda") - for i in range(2): - print(f"Output: {m(input)}") - print(torch.randn(1)) - - -if __name__ == '__main__': - args = get_args() - colossalai.launch_from_torch(config={}) - if args.module == 'distloss': - test_dist_crossentropy() - elif args.module == 'dropout': - test_dropout() - else: - print("not implemented yet") diff --git a/colossalai/shardformer/test/test.py b/colossalai/shardformer/test/test.py deleted file mode 100644 index e2d5a94c782a..000000000000 --- a/colossalai/shardformer/test/test.py +++ /dev/null @@ -1,124 +0,0 @@ -import os -import random - -import torch -import torch.nn as nn -from datasets import load_dataset -from torch.utils.data import DataLoader -from tqdm.auto import tqdm -from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler - -import colossalai -from colossalai.shardformer.shard import ShardConfig, shard_model -from colossalai.utils import get_current_device, print_rank_0 - -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' - - -def get_args(): - parser = colossalai.get_default_parser() - parser.add_argument("--mode", type=str, default='inference') - parser.add_argument("--save_model", action='store_true') - parser.add_argument("--model", type=str, default='bert-base-uncased') - return parser.parse_args() - - -def load_data(args): - tokenizer = AutoTokenizer.from_pretrained(args.model) - if tokenizer.pad_token is None: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - # tokenizer.pad_token_id = 0 - datasets = load_dataset('wikitext', 'wikitext-2-raw-v1') - # datasets=load_dataset("yelp_review_full") - tokenized_datasets = datasets.map( - lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True) - tokenized_datasets = tokenized_datasets.remove_columns(["text"]) - # tokenized_datasets=tokenized_datasets.rename_column("label","labels") - tokenized_datasets.set_format("torch") - - train_dataset = tokenized_datasets["train"] - test_dataset = tokenized_datasets["test"] - - datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt") - train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector) - eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector) - return train_dataloader, eval_dataloader - - -def inference(model: nn.Module, args): - print(model) - # print(model.wte.weight.shape) - tokenizer = AutoTokenizer.from_pretrained(args.model) - if tokenizer.pad_token is None: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - tokenizer.pad_token_id = 0 - token = "Hello, my dog is cute" - inputs = tokenizer(token, return_tensors="pt") - inputs.to("cuda") - model.eval() - model.to("cuda") - outputs = model(**inputs) - print(outputs[0]) - - -def train(model: nn.Module, args, num_epoch: int = 3): - train_dataloader, eval_dataloader = load_data(args) - optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) - num_training = num_epoch * len(train_dataloader) - progress_bar = tqdm(range(num_training)) - lr_scheduler = get_scheduler(name="linear", - optimizer=optimizer, - num_warmup_steps=0, - num_training_steps=num_training) - best_test_loss = float("inf") - model.to("cuda") - model.train() - for epoch in range(num_epoch): - progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}") - for batch in train_dataloader: - optimizer.zero_grad() - batch = {k: v.to('cuda') for k, v in batch.items()} - outputs = model(**batch) - loss = outputs.loss - loss.backward() - optimizer.step() - lr_scheduler.step() - progress_bar.update(1) - train_loss = loss - - loss = 0.0 - for batch in eval_dataloader: - batch = {k: v.to('cuda') for k, v in batch.items()} - outputs = model(**batch) - # loss = outputs.loss - assert not torch.isnan(outputs.loss), f"{batch}" - loss += outputs.loss.item() - # loss = criterion(outputs.logits, batch["input_ids"]) - test_loss = loss / len(eval_dataloader) - print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}") - if args.save_model and test_loss < best_test_loss: - best_test_loss = test_loss - torch.save(model.state_dict(), "./checkpoints/best_model.pth") - - -if __name__ == "__main__": - args = get_args() - colossalai.launch_from_torch(config=args.config) - if args.model == 'bert-base-uncased': - model = BertForMaskedLM.from_pretrained("bert-base-uncased") - elif args.model == 'gpt2': - model = GPT2LMHeadModel.from_pretrained("gpt2") - else: - raise AttributeError("model not supported") - shard_config = ShardConfig( - rank=int(str(get_current_device()).split(':')[-1]), - world_size=int(os.environ['WORLD_SIZE']), - ) - sharded_model = shard_model(model, shard_config) - - if args.mode == "train": - train(sharded_model, args) - elif args.mode == "inference": - inference(sharded_model, args) - else: - raise NotImplementedError diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py new file mode 100644 index 000000000000..55b78d040505 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -0,0 +1,103 @@ +import os +import random + +import pytest +import torch +from transformers import AutoTokenizer, BertConfig, BertForMaskedLM + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.shard import ShardConfig, shard_model +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + + +def build_model(rank, world_size): + config = BertConfig.from_pretrained('bert-base-uncased') + config.hidden_dropout_prob = 0 + config.attention_probs_dropout_prob = 0 + + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda') + + shardconfig = ShardConfig( + rank=rank, + world_size=world_size, + gather_output=True, + ) + sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config), + shardconfig).to('cuda') + + return org_model, sharded_model + + +def check_forward(org_model, sharded_model): + input = 'Hello, my dog is cute' + tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + + #orgin model + org_model.eval() + org_out = org_model(**tokenized_input) + + #shard model + sharded_model.eval() + shard_out = sharded_model(**tokenized_input) + + assert torch.allclose( + org_out[0], shard_out[0], + atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" + + +def check_backward(org_model, sharded_model): + # prepare input + input = 'Hello, my dog is cute' + tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + labels = tokenized_input['input_ids'].clone() + labels[labels == tokenizer.pad_token_id] = -100 + tokenized_input['labels'] = labels + + #orgin model + org_model.train() + org_out = org_model(**tokenized_input) + org_loss = org_out.loss + org_loss.backward() + org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad + + #shard model + sharded_model.train() + shard_out = sharded_model(**tokenized_input) + shard_loss = shard_out.loss + shard_loss.backward() + shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + org_model, sharded_model = build_model(rank, world_size) + check_forward(org_model, sharded_model) + check_backward(org_model, sharded_model) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_bert(): + spawn(check_bert, 2) + + +if __name__ == "__main__": + test_bert() diff --git a/tests/test_shardformer/test_module/test_distcrossentropy.py b/tests/test_shardformer/test_module/test_distcrossentropy.py new file mode 100644 index 000000000000..9a19ec57821d --- /dev/null +++ b/tests/test_shardformer/test_module/test_distcrossentropy.py @@ -0,0 +1,42 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) + + +def check_dist_crossentropy(rank, world_size, port, ignore_index): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') + + # prepare data + pred = torch.randn(2, 4, 8, requires_grad=True) + labels = torch.randint(8, (2, 4)) + # set some label to -100 to test the ignore index + labels[0, -1] = ignore_index + + org_pred = pred.view(-1, 8) + org_labels = labels.view(-1) + org_loss = F.cross_entropy(org_pred, org_labels) + + dist_pred = pred.chunk(world_size, -1)[rank] + dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index) + + assert torch.allclose(org_loss, dist_loss, + atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dist_crossentropy(): + ignore_index = -100 + spawn(check_dist_crossentropy, 2, ignore_index=ignore_index) + + +if __name__ == '__main__': + test_dist_crossentropy() From 661dc3b85eb442d432e624c6d458953cc5f9e557 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Mon, 12 Jun 2023 13:56:09 +0800 Subject: [PATCH 10/49] [shardformer] Unit test (#3928) * fix bug in slicer, add slicer unit test * add dropout test * use pid as dropout seed * updata dropout test with local pattern * ad todo --- colossalai/shardformer/layer/dropout.py | 4 +- colossalai/shardformer/shard/slicer.py | 16 ++-- .../test_module/test_dropout.py | 51 ++++++++++++ .../test_module/test_slicer.py | 78 +++++++++++++++++++ 4 files changed, 139 insertions(+), 10 deletions(-) create mode 100644 tests/test_shardformer/test_module/test_dropout.py create mode 100644 tests/test_shardformer/test_module/test_slicer.py diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index acc114029ac1..0f653a9be780 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -1,5 +1,4 @@ import os -import time from contextlib import contextmanager import torch @@ -14,7 +13,8 @@ class SeedManager: def __init__(self): original_state = torch.cuda.get_rng_state() - seed = int(f"{int(time.time())}{os.environ['RANK']}") + # TODO: unify this seed manager with the colossalai.context.random + seed = os.getpid() torch.cuda.manual_seed(int(seed)) self.dropout_state = torch.cuda.get_rng_state() torch.cuda.set_rng_state(original_state) diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 6d35bd193fed..09e3219f87a2 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -3,7 +3,7 @@ from ..policies.basepolicy import Col_Layer, Layer, Row_Layer from .shard_config import ShardConfig -dim_mapping = {Col_Layer: 1, Row_Layer: 0} +dim_mapping = {Col_Layer: 0, Row_Layer: 1} class Slicer(): @@ -40,7 +40,7 @@ def slice_weight_bias( # print(weight.shape, dim) if policy_layer_cls == Col_Layer: weight = self.slice_tensor(weight, dim, False, n_cast) - bias = self.slice_tensor(bias, 0, True) + bias = self.slice_tensor(bias, 0, True, n_cast) elif policy_layer_cls == Row_Layer: weight = self.slice_tensor(weight, dim, False, n_cast) else: @@ -129,13 +129,13 @@ def slice_col( """ if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() + return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) + tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1) chunk_list = [ tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) ] - return torch.cat(chunk_list, dim=0).contiguous() + return torch.cat(chunk_list, dim=1).contiguous() def slice_row( self, @@ -152,10 +152,10 @@ def slice_row( :class:`torch.Tensor`: The sliced tensor """ if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() + return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1) + tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) chunk_list = [ tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) ] - return torch.cat(chunk_list, dim=1).contiguous() + return torch.cat(chunk_list, dim=0).contiguous() diff --git a/tests/test_shardformer/test_module/test_dropout.py b/tests/test_shardformer/test_module/test_dropout.py new file mode 100644 index 000000000000..4a13eb61c1fc --- /dev/null +++ b/tests/test_shardformer/test_module/test_dropout.py @@ -0,0 +1,51 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.layer.dropout import Dropout1D +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) + + +def check_dropout(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') + + # prepare data + input = torch.randn(5, 4).to('cuda') + dropout = Dropout1D(p=0.4).to('cuda') + output_list = [] + # compare the dropout pattern in each device + for i in range(2): + output = dropout(input) + output_list.append(output) + dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)] + torch.distributed.all_gather(dist_output_list, output) + for j in range(world_size): + for k in range(world_size): + if j != k: + mask = torch.eq(dist_output_list[j], 0.0) == torch.eq(dist_output_list[k], 0.0) + assert torch.all( + mask + ) == False, f"The dropout pattern in each device is not unique\n{dist_output_list[j]}\n{dist_output_list[k]}" + # compare the dropout pattern in loacl device + for i in range(len(output_list)): + for j in range(len(output_list)): + if i != j: + mask = torch.eq(output_list[i], 0.0) == torch.eq(output_list[j], 0.0) + assert torch.all( + mask + ) == False, f"The dropout pattern in one device is not unique\n{output_list[i]}\n{output_list[j]}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_dropout(): + spawn(check_dropout, 2) + + +if __name__ == '__main__': + test_dropout() diff --git a/tests/test_shardformer/test_module/test_slicer.py b/tests/test_shardformer/test_module/test_slicer.py new file mode 100644 index 000000000000..c72a0357573b --- /dev/null +++ b/tests/test_shardformer/test_module/test_slicer.py @@ -0,0 +1,78 @@ +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.policies.basepolicy import Col_Layer, Layer, Row_Layer +from colossalai.shardformer.shard.shard_config import ShardConfig +from colossalai.shardformer.shard.slicer import Slicer +from colossalai.testing import rerun_if_address_is_in_use, spawn + +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) + + +def check_slicer(rank, world_size, port, in_feature, out_feature): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') + # initialize slicer + shardconfig = ShardConfig(rank=rank, world_size=world_size) + slicer = Slicer(shardconfig) + # initialize test data + weight = torch.randn(in_feature, out_feature) + bias = torch.randn(out_feature) + policy_layer_cls_list = [Layer, Col_Layer, Row_Layer] + n_cast_list = [None, 2, 3, 4] + # weight and bias + for n_cast in n_cast_list: + sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Layer, n_cast=n_cast) + expected_sliced_weight = weight + expected_sliced_bias = bias + assert torch.equal( + sliced_weight, expected_sliced_weight + ), f"In Layer case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" + assert torch.equal( + sliced_bias, expected_sliced_bias + ), f"In Layer case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" + + sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Col_Layer, n_cast=n_cast) + if (n_cast is None): + expected_sliced_weight = weight.chunk(world_size, dim=0)[rank] + expected_sliced_bias = bias.chunk(world_size)[rank] + else: + chunks = weight.chunk(world_size * n_cast, dim=0) + expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=0) + chunks = bias.chunk(world_size * n_cast, dim=0) + expected_sliced_bias = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)]) + assert torch.equal( + sliced_weight, expected_sliced_weight + ), f"In Col_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" + assert torch.equal( + sliced_bias, expected_sliced_bias + ), f"In Col_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}" + + sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Row_Layer, n_cast=n_cast) + if (n_cast is None): + expected_sliced_weight = weight.chunk(world_size, dim=1)[rank] + expected_sliced_bias = bias + else: + chunks = weight.chunk(world_size * n_cast, dim=1) + expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=1) + expected_sliced_bias = bias + assert torch.equal( + sliced_weight, expected_sliced_weight + ), f"In Row_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" + assert torch.equal( + sliced_bias, expected_sliced_bias + ), f"In Row_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_slicer(): + args = dict(in_feature=24, out_feature=48) + spawn(check_slicer, nprocs=2, in_feature=args['in_feature'], out_feature=args['out_feature']) + + +if __name__ == '__main__': + test_slicer() From 702513a17e6028369a2adf6ac5c6818e03838699 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Mon, 12 Jun 2023 16:52:18 +0800 Subject: [PATCH 11/49] [shardformer] Add dropout layer in shard model and refactor policy api (#3949) * add dist dropout in model * update docstring and bert policy with dropout * refactor basepolicy and sharded, update bert * update format * update gpt2 policy * update bert policy * remove unused code * update readme for new policy usage --- colossalai/shardformer/README.md | 80 +++++----- colossalai/shardformer/policies/basepolicy.py | 68 +++++---- colossalai/shardformer/policies/bert.py | 139 ++++++++++-------- colossalai/shardformer/policies/gpt2.py | 40 +++-- colossalai/shardformer/shard/sharder.py | 108 +++++++------- colossalai/shardformer/shard/slicer.py | 4 +- colossalai/shardformer/utils/utils.py | 2 +- 7 files changed, 255 insertions(+), 186 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 222626db3e9d..b8357c203939 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -55,7 +55,7 @@ colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py ## 💡 Policy -If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. +If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. Please refer to any policy that we have pre-established, like [bert policy](./policies/bert.py) or [gpt2 policy](./policies/gpt2.py). You should do: @@ -68,7 +68,7 @@ You should do: - Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method. 4. Overwrite or add the param functions - These functions use a suffix to record the path of weight or bias for the layer. - - The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively. + - The return is a list contains some `Col_Layer`, `Row_Layer` or `Dropout_Layer` objects, which means slice along col and row respectively or as dropout layer, refer to CLASS `Layer` for more details. 5. Overwrite `binding_policy` (Optional) - Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers. - This function will return a dict, the key and value are the suffix of weight need to be binded. @@ -123,7 +123,7 @@ class CustomPolicy(Policy): raise NotImplementedError @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: + def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]: r""" Return the dict for the inject model @@ -133,12 +133,12 @@ class CustomPolicy(Policy): (OrignModel, CustomModel) in `CustomModel`, we can overwrite the forward and backward process """ - return () + return None @staticmethod - def binding_policy() -> Dict: + def binding_policy() -> Union[Dict[str, str], None]: r""" - Return the dict for the binding model + Return the dict for the binding model, None means no need to bind Return: This method should return the binding relationship for some layers share the weight or bias, @@ -148,69 +148,70 @@ class CustomPolicy(Policy): "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", } """ - return NotImplementedError + return None @staticmethod - def attn_in() -> List: - """ + def attn_in() -> Union[List, None]: + r""" Attention qkv layer + In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be + ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters + in ``Layer`` object can refer to the ``Layer`` class. Returns: List[Layer]: List of layer object, each layer is the new """ - return NotImplementedError + return None @staticmethod - def attn_out() -> List: - """ + def attn_out() -> Union[List, None]: + r""" Attention output projection layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_in() -> List: - """ + def mlp_in() -> Union[List, None]: + r""" h -> 4h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_out() -> List: - """ + def mlp_out() -> Union[List, None]: + r""" 4h -> h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def embedding() -> List: - """ + def embedding() -> Union[List, None]: + r""" Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums Return: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def unembedding() -> List: - """ - Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums + def unembedding() -> Union[List, None]: + r""" + Partially slice the embedding layer, None means there is no unembedding layer Return: List[Layer]: List of layer object """ - return NotImplementedError + return None ``` @@ -232,21 +233,26 @@ class CustomPolicy(Policy): - CLASS `Layer`: Parameters: - - weight (str): The weight suffix of the layer - - bias (str): The bias suffix of the layer + - suffix: (str): the suffix of the layer to indicate the attribute of the layer. - replace_layer (:class:`colosalai.nn`): The layer to replace the original layer - ignore (bool): Whether to ignore this layer if it is not in the model + - reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], but in GPT2 `Conv1D` layer is [in, out] which is reversed. + - n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices, but in multi-head attention, we need to chunk the weight with the number of $ devices * n\_head $, and each device should have a part of Q, K and V weight. - This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class. + This class is a base class used to specify the replacement policy and the suffix the layer for a particular layer. CLASS `Col_Layer(Layer)`: + - weight (str): The weight suffix of the layer + - bias (str): The bias suffix of the layer - gather_output (bool): Whether the output of this layer can be gathered, like the last layer can be gathered, but most of the time, the intermediate layers of the model do not need to be gathered. - This class inherited from `Layer`, representing the layer will be sliced along column. + This class inherited from `Layer`, representing the layer will be sliced along colum and indicate the attributes of weight and bias. Setting `bias` to `None` means ignoring bias, regardless of whether or not it originally exists. CLASS `Row_Layer(Layer)`: + - weight (str): The weight suffix of the layer + - bias (str): The bias suffix of the layer - This class inherited from `Layer`, representing the layer will be sliced along row. + This class inherited from `Layer`, representing the layer will be sliced along row. Just like `Col_Layer` but in tensor parrallel, there is no need to gather the output of layer sliced by row. - CLASS `Policy`: @@ -254,29 +260,37 @@ class CustomPolicy(Policy): - `Policy.attn_in()/attn_out()/mlp_in()/mlp_out()/embedding()/unembedding()`...... These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions. + - `Policy.argument_policy()` In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach. + - `Policy.inject_policy()` This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else. + - `Policy.binding_policy()` This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters. + - CLASS `ModelSharder(model, policy)`: This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model. + - `ModelShard.inject_model()` This function is used to inject the model to modify the forward and backward progress. + - `ModelShard.replace_layer()` This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication. + - `ModelShard.bind_layer()` This function is used to help different layers share weight or bias. + - CLASS `Slicer`: This class is used to slice tensor according to policy. diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 644d115a270e..d55df59fdc8b 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -1,7 +1,7 @@ # part of code modified from https://github.com/tunib-ai/parallelformers from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple, Type +from typing import Any, Callable, Dict, List, Tuple, Union import torch.nn as nn @@ -25,8 +25,7 @@ class Layer: The layer object for the policy Args: - weight (str): The weight suffix of the layer - bias (str): The bias suffix of the layer + suffix: (str): the suffix of the layer. replace_layer (:class:`colosalai.nn`): The layer to replace the original layer ignore (bool): Whether to ignore this layer if it is not in the model reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], @@ -35,8 +34,7 @@ class Layer: but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and each device should have a part of Q, K and V weight. """ - weight: str = None - bias: str = None + suffix: str = None replace_layer: Any = None ignore: bool = False reversed: bool = False @@ -46,20 +44,40 @@ class Layer: @dataclass class Col_Layer(Layer): r""" - Class for col shard layer in MegatronLM + Class for col shard layer in tensor parrallel Args: + weight (str): The weight suffix of the layer + bias (str): The bias suffix of the layer gather_output (bool): Whether to gather the output of the layer """ + weight: str = None + bias: str = None gather_output: bool = False @dataclass class Row_Layer(Layer): r""" - Class for col shard layer in MegatronLM + Class for col shard layer in tensor parrallel + + Args: + weight (str): The weight suffix of the layer + bias (str): The bias suffix of the layer """ - pass + weight: str = None + bias: str = None + + +@dataclass +class Dropout_Layer(Layer): + r""" + Class for dropout layer in tensor parrallel + + Args: + p (str): The dropout rate suffix of the layer + """ + p: str = None class Policy(): @@ -82,14 +100,14 @@ class for the example. """ @staticmethod - def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: + def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]: r""" Return the dict for the modify policy, the key is the original layer class and the value is the argument for the modify layer Args: model_config (:class:`tansformer.Config`): The config of transformer model - shard_config (:class:`ShardConfig`): The config for sharding model + world_size (int)): The world size of sharding model Return: Dict for the modify policy, @@ -126,7 +144,7 @@ def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument raise NotImplementedError @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: + def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]: r""" Return the dict for the inject model @@ -139,9 +157,9 @@ def inject_policy() -> Tuple[nn.Module, nn.Module]: return None @staticmethod - def binding_policy() -> Dict: + def binding_policy() -> Union[Dict[str, str], None]: r""" - Return the dict for the binding model + Return the dict for the binding model, None means no need to bind Return: This method should return the binding relationship for some layers share the weight or bias, @@ -154,7 +172,7 @@ def binding_policy() -> Dict: return None @staticmethod - def attn_in() -> List: + def attn_in() -> Union[List, None]: r""" Attention qkv layer In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be @@ -164,52 +182,52 @@ def attn_in() -> List: Returns: List[Layer]: List of layer object, each layer is the new """ - return NotImplementedError + return None @staticmethod - def attn_out() -> List: + def attn_out() -> Union[List, None]: r""" Attention output projection layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_in() -> List: + def mlp_in() -> Union[List, None]: r""" h -> 4h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_out() -> List: + def mlp_out() -> Union[List, None]: r""" 4h -> h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def embedding() -> List: + def embedding() -> Union[List, None]: r""" Partially slice the embedding layer Return: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def unembedding() -> List: + def unembedding() -> Union[List, None]: r""" - Partially slice the embedding layer + Partially slice the embedding layer, None means there is no unembedding layer Return: List[Layer]: List of layer object diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 5d489f41986c..67e910d521e9 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -5,7 +5,7 @@ import colossalai.shardformer.layer.layers as col_nn -from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer +from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer class BertPolicy(Policy): @@ -28,123 +28,126 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: Argument( attr_dict={ # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size, }, param_funcs=[ BertPolicy.embedding, ]), - BertLMPredictionHead: - Argument( - attr_dict={ - # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice - }, - param_funcs=[ - BertPolicy.unembedding, - ]) } @staticmethod - def binding_policy() -> Dict: + def binding_policy(): return { "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", } @staticmethod - def attn_in() -> List: + def attn_in(): return [ Col_Layer( - weight="attention.self.query.weight", - bias="attention.self.query.bias", + suffix="attention.self.query", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), Col_Layer( - weight="attention.self.key.weight", - bias="attention.self.key.bias", + suffix="attention.self.key", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), Col_Layer( - weight="attention.self.value.weight", - bias="attention.self.value.bias", + suffix="attention.self.value", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), + Dropout_Layer( + suffix="attention.self.dropout", + p="p", + replace_layer=col_nn.Dropout1D, + ), Col_Layer( - weight="crossattention.self.query.weight", - bias="crossattention.self.query.bias", + suffix="crossattention.self.query", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ignore=True, ), Col_Layer( - weight="crossattention.self.key.weight", - bias="crossattention.self.key.bias", + suffix="crossattention.self.key", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ignore=True, ), Col_Layer( - weight="crossattention.self.value.weight", - bias="crossattention.self.value.bias", + suffix="crossattention.self.value", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ignore=True, ), ] @staticmethod - def attn_out() -> List: + def attn_out(): return [ Row_Layer( - weight="attention.output.dense.weight", - bias="attention.output.dense.bias", + suffix="attention.output.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Row, ), + Dropout_Layer( + suffix="attention.output.dropout", + p="p", + replace_layer=col_nn.Dropout1D, + ), Row_Layer( - weight="crossattention.output.dense.weight", - bias="crossattention.output.dense.bias", + suffix="crossattention.output.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Row, ignore=True, ), ] @staticmethod - def mlp_in() -> List: + def mlp_in(): return [ Col_Layer( - weight="intermediate.dense.weight", - bias="intermediate.dense.bias", + suffix="intermediate.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), ] @staticmethod - def mlp_out() -> List: + def mlp_out(): return [ Row_Layer( - weight="output.dense.weight", - bias="output.dense.bias", + suffix="output.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Row, ), + Dropout_Layer( + suffix="output.dropout", + p="p", + replace_layer=col_nn.Dropout1D, + ) ] @staticmethod - def embedding() -> List: + def embedding(): return [Col_Layer( - weight="word_embeddings.weight", + suffix="word_embeddings", + weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D, )] - @staticmethod - def unembedding() -> List: - return [ - Col_Layer( - weight="decoder.weight", - bias="decoder.bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, - ) - ] - from transformers import BertForMaskedLM @@ -154,18 +157,36 @@ def unembedding() -> List: class BertForMaskedLMPolicy(BertPolicy): @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: + def argument_policy(config, world_size): + base_argument = BertPolicy.argument_policy(config, world_size) + argument = { + BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ + BertForMaskedLMPolicy.unembedding, + ]), + } + argument.update(base_argument) + return argument + + @staticmethod + def inject_policy(): # return (BertForMaskedLM, BertForMaskedLM_) return None + @staticmethod + def unembedding(): + return [ + Col_Layer( + suffix="decoder", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ) + ] + class BertForSequenceClassificationPolicy(BertPolicy): @staticmethod - def inject_policy() -> Dict: - return {} - - -# model = BertForMaskedLM.from_pretrained("bert-base-uncased") -# _ = BertForMaskedLMPolicy(model) -# print(isinstance(model,list(_.inject_policy().keys())[0])) + def inject_policy(): + return None diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 44dc9c72f986..0d4342e75783 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -40,19 +40,22 @@ def argument_policy(config, world_size): @staticmethod def attn_in() -> List: return [ - Col_Layer(weight="attn.c_attn.weight", - bias="attn.c_attn.bias", + Col_Layer(suffix="attn.c_attn", + weight="weight", + bias="bias", n_cast=3, reversed=True, replace_layer=col_nn.Linear1D_Col), - Col_Layer(weight="crossattention.c_attn.weight", - bias="crossattention.c_attn.bias", + Col_Layer(suffix="crossattention.c_attn", + weight="weight", + bias="bias", n_cast=2, reversed=True, ignore=True, replace_layer=col_nn.Linear1D_Col), - Col_Layer(weight="crossattention.q_attn.weight", - bias="crossattention.q_attn.bias", + Col_Layer(suffix="crossattention.q_attn", + weight="weight", + bias="bias", reversed=True, ignore=True, replace_layer=col_nn.Linear1D_Col) @@ -61,12 +64,14 @@ def attn_in() -> List: @staticmethod def attn_out() -> List: return [ - Row_Layer(weight="attn.c_proj.weight", - bias="attn.c_proj.bias", + Row_Layer(suffix="attn.c_proj", + weight="weight", + bias="bias", reversed=True, replace_layer=col_nn.Linear1D_Row), - Row_Layer(weight="crossattention.c_proj.weight", - bias="crossattention.c_proj.bias", + Row_Layer(suffix="crossattention.c_proj", + weight="weight", + bias="bias", reversed=True, ignore=True, replace_layer=col_nn.Linear1D_Row) @@ -75,21 +80,23 @@ def attn_out() -> List: @staticmethod def mlp_in() -> List: return [ - Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col), + Col_Layer(suffix="mlp.c_fc", weight="weight", bias="bias", reversed=True, + replace_layer=col_nn.Linear1D_Col), ] @staticmethod def mlp_out() -> List: return [ - Row_Layer(weight="mlp.c_proj.weight", - bias="mlp.c_proj.bias", + Row_Layer(suffix="mlp.c_proj", + weight="weight", + bias="bias", reversed=True, replace_layer=col_nn.Linear1D_Row) ] @staticmethod def embedding() -> List: - return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)] + return [Col_Layer(suffix="wte", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)] from transformers import GPT2LMHeadModel @@ -111,8 +118,9 @@ def argument_policy(config, world_size): @staticmethod def unembedding() -> List: return [ - Col_Layer(weight="lm_head.weight", - bias="lm_head.bias", + Col_Layer(suffix="lm_head", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, gather_output=True) ] diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 159bebccd02d..95184cfe6929 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -5,7 +5,7 @@ from transformers.pytorch_utils import Conv1D from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Policy +from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer from ..utils.utils import getattr_, hasattr_, setattr_ from .shard_config import ShardConfig from .slicer import Slicer @@ -141,65 +141,73 @@ def shard_one_layer( for func in param_funcs: policy_layers = func() for policy_layer in policy_layers: - weight = None - bias = None - weight_attr = policy_layer.weight - bias_attr = policy_layer.bias + suffix = policy_layer.suffix replace_layer_cls = policy_layer.replace_layer ignore = policy_layer.ignore - n_cast = policy_layer.n_cast reversed = policy_layer.reversed - if policy_layer.__class__.__name__ == "Col_Layer": - gather_output = policy_layer.gather_output and self.shard_config.gather_output - - if weight_attr is not None: - if hasattr_(org_layer, weight_attr): - weight = getattr_(org_layer, weight_attr) - elif not ignore: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") - - if bias_attr is not None: - if hasattr_(org_layer, bias_attr): - bias = getattr_(org_layer, bias_attr) - elif not ignore: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") - - # dont have the attribute in policy, and ignore is true - if weight is None and bias is None and ignore: - continue - - # set the sliced weight and bias to the new nn_col layer - assert weight is not None or bias is not None - layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr) + n_cast = policy_layer.n_cast - # slice weight and bias - weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed) + assert replace_layer_cls is not None, 'replace_layer should not be None' # create new object to replace the origin layer - if replace_layer_cls is not None: - if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)): - if replace_layer_cls.__name__ == "Linear1D_Row": - replace_layer = replace_layer_cls(weight.shape[1], - weight.shape[0], - bias=False if bias is None else True) - elif replace_layer_cls.__name__ == "Linear1D_Col": - replace_layer = replace_layer_cls(weight.shape[0], - weight.shape[1], - bias=False if bias is None else True, - gather_output=gather_output) - setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) - self.set_param(replace_layer, weight, bias) - elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): + # Linear + suffix_layer = getattr_(org_layer, suffix, ignore=True) + assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}" + if suffix_layer is None and ignore: + continue + if isinstance(policy_layer, (Col_Layer, Row_Layer)): + weight = None + bias = None + weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None + bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None + + if weight_attr is not None: + if hasattr_(org_layer, weight_attr): + weight = getattr_(org_layer, weight_attr) + else: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") + + if bias_attr is not None: + if hasattr_(org_layer, bias_attr): + bias = getattr_(org_layer, bias_attr) + else: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") + + # set the sliced weight and bias to the new nn_col layer + assert weight is not None or bias is not None + + # slice weight and bias + weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed) + + if replace_layer_cls.__name__ == "Linear1D_Row": + replace_layer = replace_layer_cls(weight.shape[1], + weight.shape[0], + bias=False if bias is None else True) + elif replace_layer_cls.__name__ == "Linear1D_Col": + gather_output = policy_layer.gather_output and self.shard_config.gather_output + replace_layer = replace_layer_cls(weight.shape[0], + weight.shape[1], + bias=False if bias is None else True, + gather_output=gather_output) + elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D": replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], - getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) - setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) - self.set_param(replace_layer, weight, bias) + getattr_(org_layer, f"{suffix}.padding_idx", ignore=True)) + # setattr_(org_layer, suffix, replace_layer, ignore=ignore) + # self.set_param(replace_layer, weight, bias) else: raise NotImplementedError( - f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") - # do not replace the layer object, just replace the weight and bias + f"Replacing to {replace_layer_cls.__name__} is not implemented so far") + setattr_(org_layer, suffix, replace_layer, ignore=ignore) + self.set_param(replace_layer, weight, bias) + # dropout + elif isinstance(policy_layer, Dropout_Layer): + p_attr = suffix + '.' + policy_layer.p + p = getattr_(org_layer, p_attr, ignore=True) + replace_layer = replace_layer_cls(p) + setattr_(org_layer, suffix, replace_layer, ignore=ignore) else: - self.set_param(org_layer, layer_attr, weight, bias) + raise NotImplementedError( + f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far") def set_param(self, layer: Any, diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 09e3219f87a2..0bf8f58b8544 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -1,6 +1,6 @@ import torch -from ..policies.basepolicy import Col_Layer, Layer, Row_Layer +from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer from .shard_config import ShardConfig dim_mapping = {Col_Layer: 0, Row_Layer: 1} @@ -33,7 +33,7 @@ def slice_weight_bias( bias: (:class:`torch.nn.Module`): The bias of the layer policy_layer_class (:class:`Policy`): The class represent how to slice the tensor """ - if policy_layer_cls == Layer: + if policy_layer_cls in [Layer, Dropout_Layer]: return weight, bias dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls]) diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py index eb84edd88404..2c02b6f69a3e 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/utils/utils.py @@ -37,7 +37,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False): setattr(obj, attrs[-1], value) -def getattr_(obj, attr: str, ignore: bool = None): +def getattr_(obj, attr: str, ignore: bool = False): r""" Get the object's multi sublevel attr From 17d160739edb4187c001ae38dd0b21fa95a9a853 Mon Sep 17 00:00:00 2001 From: wukong1992 Date: Tue, 13 Jun 2023 14:44:40 +0800 Subject: [PATCH 12/49] [shardformer] support llama model using shardformer (#3969) adjust layer attr --- .../shardformer/layer/dist_crossentropy.py | 2 +- colossalai/shardformer/policies/autopolicy.py | 14 ++ colossalai/shardformer/policies/llama.py | 122 ++++++++++++++++++ .../test_model/test_shard_llama.py | 106 +++++++++++++++ 4 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 colossalai/shardformer/policies/llama.py create mode 100644 tests/test_shardformer/test_model/test_shard_llama.py diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py index 05c04bb545c1..ff05209fefe8 100644 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ b/colossalai/shardformer/layer/dist_crossentropy.py @@ -21,7 +21,7 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: and can be rewrite as: loss = log(sum(exp(x[i])) - x[class] - To avoid the `nan` of log(sim(exp(x[i]))), we minus the max of x[i] + To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i] Args: vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 54cc63ba124f..27fd09b4561b 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -19,6 +19,20 @@ def build_policies(): from .bert import BertForSequenceClassificationPolicy auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy + from transformers.models.llama.modeling_llama import LlamaModel + + from .llama import LlamaPolicy + auto_policy_dict[LlamaModel] = LlamaPolicy + + from transformers import LlamaForSequenceClassification + + from .llama import LlamaForSequenceClassificationPolicy + auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy + + from transformers import LlamaForCausalLM + + from .llama import LlamaForCausalLMPolicy + auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy from transformers import GPT2Model diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py new file mode 100644 index 000000000000..fac6765cdcb5 --- /dev/null +++ b/colossalai/shardformer/policies/llama.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Tuple, Type + +import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel + +import colossalai.shardformer.layer.layers as col_nn + +from .basepolicy import Argument, Col_Layer, Policy, Row_Layer + + +class LlamaPolicy(Policy): + + @staticmethod + def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: + return { + LlamaDecoderLayer: + Argument(attr_dict={ + "self_attn.hidden_size": config.hidden_size // world_size, + "self_attn.num_heads": config.num_attention_heads // world_size, + }, + param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]), + LlamaModel: + Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings]) + } + + @staticmethod + def attn_layer() -> List: + return [ + Col_Layer( + suffix="self_attn.q_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + suffix="self_attn.k_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + suffix="self_attn.v_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Row_Layer( + suffix="self_attn.o_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Row, + ) + ] + + @staticmethod + def mlp_layer() -> List: + return [ + Col_Layer( + suffix="mlp.gate_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ), + Col_Layer( + suffix="mlp.up_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Row, + gather_output=True, + ), + Col_Layer( + suffix="mlp.down_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ), + ] + + @staticmethod + def embeddings() -> List: + return [Col_Layer( + suffix="embed_tokens", + weight="weight", + replace_layer=col_nn.VocabParallelEmbedding1D, + )] + +from transformers import LlamaForCausalLM + + +class LlamaForCausalLMPolicy(LlamaPolicy): + + @staticmethod + def argument(config, world_size): + llamapolicy = LlamaPolicy.argument_policy(config, world_size) + argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])} + argument.update(llamapolicy) + + @staticmethod + def lm_head() -> List: + return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)] + + +from transformers import LlamaForSequenceClassification + + +class LlamaForSequenceClassificationPolicy(LlamaPolicy): + + @staticmethod + def argument(config, world_size): + llamapolicy = LlamaPolicy.argument_policy(config, world_size) + argument = { + LlamaForSequenceClassification: + Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score]) + } + argument.update(llamapolicy) + + @staticmethod + def score() -> List: + return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)] diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py new file mode 100644 index 000000000000..689898bbbad2 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -0,0 +1,106 @@ +import copy +import os +import random + +import pytest +import torch +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaTokenizerFast + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.shard import ShardConfig, shard_model +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=4, mode='1d')),) +tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") + + +def build_model(rank, world_size): + cfg = LlamaConfig(num_hidden_layers=16) + org_model = LlamaForCausalLM(cfg) + + shardconfig = ShardConfig( + rank=rank, + world_size=world_size, + gather_output=True, + ) + org_model = org_model.to('cuda') + + org_model_forshard = copy.deepcopy(org_model) + sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda') + + return org_model, sharded_model + + +def check_forward(org_model, sharded_model): + input = 'Hello, my dog is cute' + inputs = tokenizer(input, return_tensors='pt').to('cuda') + del inputs["token_type_ids"] + del inputs["attention_mask"] + #orgin model + org_model.eval() + org_out = org_model(**inputs) + + #shard model + sharded_model.eval() + shard_out = sharded_model(**inputs) + + assert torch.allclose( + org_out[0], shard_out[0], + atol=1e-4), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" + + +def check_backward(org_model, sharded_model): + # prepare input + input = 'Hello, my dog is cute' + tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + del tokenized_input["token_type_ids"] + del tokenized_input["attention_mask"] + labels = tokenized_input['input_ids'].clone() + labels[labels == tokenizer.pad_token_id] = -100 + tokenized_input['labels'] = labels + + #orgin model + org_model.train() + org_out = org_model(**tokenized_input) + org_loss = org_out.loss + org_loss.backward() + org_grad = org_model.model.layers[0].self_attn.q_proj.weight.grad + + torch.cuda.empty_cache() + #shard model + sharded_model.train() + shard_out = sharded_model(**tokenized_input) + shard_loss = shard_out.loss + shard_loss.backward() + shard_grad = sharded_model.model.layers[0].self_attn.q_proj.weight.grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + org_model, sharded_model = build_model(rank, world_size) + check_forward(org_model, sharded_model) + check_backward(org_model, sharded_model) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_llama(): + spawn(check_llama, 4) + + +if __name__ == "__main__": + test_llama() From e849d1b2bf05a9cb81a3b91aa58be69631f33eaf Mon Sep 17 00:00:00 2001 From: wukong1992 Date: Thu, 15 Jun 2023 16:50:08 +0800 Subject: [PATCH 13/49] [shardformer] shardformer support t5 model (#3994) test t5 --- applications/Chat/coati/trainer/.sft.py.swp | Bin 0 -> 20480 bytes colossalai/shardformer/layer/layers.py | 8 +- colossalai/shardformer/policies/autopolicy.py | 9 + colossalai/shardformer/policies/basepolicy.py | 12 ++ colossalai/shardformer/policies/t5.py | 159 ++++++++++++++++++ colossalai/shardformer/shard/sharder.py | 11 +- colossalai/shardformer/shard/slicer.py | 6 +- colossalai/shardformer/utils/utils.py | 25 ++- requirements/requirements-test.txt | 1 + .../test_model/test_shard_t5.py | 99 +++++++++++ 10 files changed, 320 insertions(+), 10 deletions(-) create mode 100644 applications/Chat/coati/trainer/.sft.py.swp create mode 100644 colossalai/shardformer/policies/t5.py create mode 100644 tests/test_shardformer/test_model/test_shard_t5.py diff --git a/applications/Chat/coati/trainer/.sft.py.swp b/applications/Chat/coati/trainer/.sft.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..302cf2a775338fb4fcd6b9b12c1a8e80f3969a01 GIT binary patch literal 20480 zcmeHOU5q4E6)psY74VP5@FaJ89H$z3t7mp0aohBo-E~(9yMp_JSthfc>h9aq#Z*@{ zRn;uJ!vHZRh#HOZ#2YFJFB0QVSObsxWI{~Lk{BNh;)6b@tf-4W;QF2WQ+2DldsbdF zA-9unZdKiL?>+as=box_s;0Yq@0~~4UT@jK?*oo=(;FxI-#GnI=k?bd=l&oa`tc}; zsBW8IF4yNr58UI0{+tGnwG2>WBO9Z2IGI81(!sh@-T5aOh{vUW~J5 zmiR%$lV&=|yz$v>bg)&H0n5OZ7`V~dvwKBs^@Z9_cGHJWUa5htw+vVYECZGS%YbFT zGGH073|I!UFsSOzQumI2Ga|A7I&?>KiO*ng4|0C@keHBLF9S~i4+E<} z3%D7$h(y4zfBv-t%5*a`ZrO?|Z0V174>laU{aAzEycM%t>?or)eT z7)T}dwb^d(4(3OF7NniTk2X5XZoM_Fg3>&mWaCMedUQH1q-rrkF||?L=bgLq$?QqGObFGlw zsur=BXQ9tuXN* zT09OpAri(tKNK?cdNBm{(In*^F^V*UJ|D0FMx$9b)UIt9LaZCdA#DAyAEq3p115z* z{>`AYG)r^FmE&kK>WKXTk05G(MX;eL>PT*BWlX2cfHuQQt2;ICh7{ju+ zyrA2V(pmqIh0}G%TZ!m0! zuzJ@o<|mwylTIh&ahzlv_n``0Jis|OV;YhZYFtuXv)eOs7m2C0u1LIu6La(7W#S>G1>6z+*f z#61zckXD6kb!|=ILP}U%i`_kAn<|c4OtV$A62@4Zk{!!4@_f0?{%%}z zh}c@KCQ3RMT6H$ylGan^_JdwVH;sWm3A2=CG22`Ulcm-Q-*x*|y~VO2XG#`z@Wt%R zFbFxg4YI%wgL80=vP*JA`5^hSF}J*I<*`=R)DVR{MxhhdpCr_^V44oplT1EBv6jrm zLYMp9t-6In>cS1e&XpfS@1n=N8~JRPLiriptf+2;5EDXj_FAjFn-5K6S(muQS)X-u*-Q38O%Hk?52`wot8x1|xR_{*cYCBZT^NQ_SOZVw! zanc)lQRE$s>8_)Ckd-CA_rv0hN8c0j>twB9nOT-_SZH6mW=}9hM%)X1JmDWdd_pyL zRcYMl|6v-xGf3hQ^CnpkrYsnd3oM>j@#GPA`uxG5$CaF#cEEGtX}{LU_pEAE*sdDj zpz=|wzUza2Ri)H*WL1Xe-4qKm;z{b^gg(~RlBFB>Eb(}QN5MH6*NFMy$*>sco`lP> z`X{6yPY$3ucN_mSSki9eU&b}T8^sG+A39cQi1ijjdSs(oqVSidoG?T!l3S^x_z(|v zxl(UJ0sR=R1f+4{IX_E+?u71NeoE0trh5wgVA3!DOO z27ZE={t@69@Ht>7@G5lu4e%WB7*GRt0yKvuD(%lQU>UFsSOzQumVs>I)s>jiJTo%Io9T8B?Yr36Urjv+!=8{+FB~71M~<8-s$A*#%|0$B8Agnj(`}mV2&{9PSOKJJ^azp3@v(PnE z!oFvzi?B{4y2|MzYJiQ2?KN%6rMDy7Y}o~s2*;LnH)}M%TFW%rt+IY{$nyRcq1em~ zF7pyw>;PriS@Eig9mCdTo0XkxIsvt1$PMT)=8rgVmBZ54{$zwJq?-9#ATos!(2Ju* zOlloD$&4rDj-=CYOnH5*UxO5+=^C@#=FNgBHiJI1I)4ub(zU9NDNsGG2}&U?%K2M~ z+SS4Yg2J5gB)M>$NHb)e5Lw#Q_QrZG26NCeC8`7U!w>MbEwzq59?&SKNOkO z%6pM&i^`;E$toL3w-AzfzQ{w)NN-I?iWtKD+hn9b%7 literal 0 HcmV?d00001 diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py index f5123885bbe4..a9f3cf5ad14c 100644 --- a/colossalai/shardformer/layer/layers.py +++ b/colossalai/shardformer/layer/layers.py @@ -770,6 +770,7 @@ def __init__(self, embedding_dim: int, padding_idx: int = None, dtype: torch.dtype = None, + gather_output: bool = True, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -782,6 +783,7 @@ def __init__(self, self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs + self.gather_output = gather_output self.weight = Parameter( torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) @@ -832,8 +834,10 @@ def _save_to_global_state_dict(self, destination, prefix, keep_vars): def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + if self.gather_output: + output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) + else: + output = output_parallel return output diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 27fd09b4561b..d4425497bd8e 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -43,6 +43,15 @@ def build_policies(): from .gpt2 import GPT2LMHeadModelPolicy auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy + + from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy + from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model + t5 = { + T5ForConditionalGeneration: T5ForConditionalGenerationPolicy, + T5EncoderModel: T5EncoderModelPolicy, + T5Model: T5ModelPolicy, + } + auto_policy_dict.update(t5) return auto_policy_dict diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index d55df59fdc8b..ba3a97f1bbcd 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -80,6 +80,18 @@ class Dropout_Layer(Layer): p: str = None +@dataclass +class Embedding_Layer(Layer): + r""" + Class for col shard layer in tensor parrallel + + Args: + weight (str): The weight suffix of the layer + """ + weight: str = None + gather_output: bool = True + + class Policy(): r""" The base class for all the policies diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py new file mode 100644 index 000000000000..7b013a37845a --- /dev/null +++ b/colossalai/shardformer/policies/t5.py @@ -0,0 +1,159 @@ +from typing import Dict + +import torch.nn as nn +from torch.nn import Embedding +from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5Block, + T5DenseActDense, + T5DenseGatedActDense, + T5LayerCrossAttention, + T5LayerFF, + T5LayerSelfAttention, + T5Model, + T5Stack, +) + +import colossalai.shardformer.layer.layers as col_nn + +from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer + + +class T5ModelPolicy(Policy): + + @staticmethod + def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: + print('config heads', config.num_heads) + return { + T5Stack: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]), + T5Block: + Argument(attr_dict={}, param_funcs=[]), + T5LayerSelfAttention: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + T5LayerCrossAttention: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + T5Attention: + Argument(attr_dict={ + "d_model": config.d_model // world_size, + "n_heads": config.num_heads // world_size, + "inner_dim": config.num_heads * config.d_kv // world_size, + }, + param_funcs=[T5ModelPolicy.attn_layer]), + T5LayerFF: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + T5DenseGatedActDense: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]), + T5DenseActDense: + Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]), + } + + @staticmethod + def dense_gated_layer(): + return [ + Col_Layer( + suffix="wi_0", + weight="weight", + replace_layer=col_nn.Linear1D_Col, + ), + Row_Layer( + suffix="wi_1", + weight="weight", + replace_layer=col_nn.Linear1D_Row, + ), + Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True) + ] + + @staticmethod + def dense_act_layer(): + return [ + Col_Layer( + suffix="wi", + weight="weight", + replace_layer=col_nn.Linear1D_Col, + ), + Row_Layer( + suffix="wo", + weight="weight", + replace_layer=col_nn.Linear1D_Row, + ) + ] + + @staticmethod + def attn_layer(): + return [ + Col_Layer( + suffix="q", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + suffix="k", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + suffix="v", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Row_Layer( + suffix="o", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Row, + ), + ] + + @staticmethod + def dropout(): + return [Dropout_Layer( + suffix="dropout", + p="p", + replace_layer=col_nn.Dropout1D, + )] + + @staticmethod + def embedding(): + return [ + Embedding_Layer( + suffix="block[0].layer[0].SelfAttention.relative_attention_bias", + weight="weight", + replace_layer=col_nn.Embedding1D, + gather_output=False, + ) + ] + + +from transformers import T5ForConditionalGeneration + + +class T5ForConditionalGenerationPolicy(T5ModelPolicy): + + @staticmethod + def argument_policy(config, world_size): + base_argument = T5ModelPolicy.argument_policy(config, world_size) + argument = { + T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head]) + } + argument.update(base_argument) + return argument + + @staticmethod + def lm_head(): + return [Col_Layer( + suffix="lm_head", + weight="weight", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + )] + + +from transformers import T5EncoderModel + + +class T5EncoderModelPolicy(T5ModelPolicy): + pass diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 95184cfe6929..8f6514cb4f5f 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -5,7 +5,7 @@ from transformers.pytorch_utils import Conv1D from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer +from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer from ..utils.utils import getattr_, hasattr_, setattr_ from .shard_config import ShardConfig from .slicer import Slicer @@ -155,11 +155,11 @@ def shard_one_layer( assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}" if suffix_layer is None and ignore: continue - if isinstance(policy_layer, (Col_Layer, Row_Layer)): + if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)): weight = None bias = None weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None - bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None + bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None if weight_attr is not None: if hasattr_(org_layer, weight_attr): @@ -189,6 +189,11 @@ def shard_one_layer( weight.shape[1], bias=False if bias is None else True, gather_output=gather_output) + elif replace_layer_cls.__name__ == "Embedding1D": + gather_output = policy_layer.gather_output + replace_layer = replace_layer_cls(weight.shape[0], + weight.shape[1], + gather_output=gather_output) elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D": replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], getattr_(org_layer, f"{suffix}.padding_idx", ignore=True)) diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 0bf8f58b8544..860533dca50d 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -1,9 +1,9 @@ import torch -from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer +from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer from .shard_config import ShardConfig -dim_mapping = {Col_Layer: 0, Row_Layer: 1} +dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1} class Slicer(): @@ -43,6 +43,8 @@ def slice_weight_bias( bias = self.slice_tensor(bias, 0, True, n_cast) elif policy_layer_cls == Row_Layer: weight = self.slice_tensor(weight, dim, False, n_cast) + elif policy_layer_cls == Embedding_Layer: + weight = self.slice_tensor(weight, dim, False, n_cast) else: raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported") if reversed: diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py index 2c02b6f69a3e..05a6a3ae6c30 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/utils/utils.py @@ -1,3 +1,22 @@ +import re + + +def get_obj_list_element(obj, a): + re_pattern = r'\[\d+\]' + prog = re.compile(re_pattern) + result = prog.search(a) + if result: + matched_brackets = result.group() + matched_index = matched_brackets.replace('[', '') + matched_index = matched_index.replace(']', '') + a_ = a.replace(matched_brackets, '') + container_obj = getattr(obj, a_) + obj = container_obj[int(matched_index)] + else: + obj = getattr(obj, a) + return obj + + def hasattr_(obj, attr: str): r""" Check whether the object has the multi sublevel attr @@ -9,7 +28,7 @@ def hasattr_(obj, attr: str): attrs = attr.split('.') for a in attrs: try: - obj = getattr(obj, a) + obj = get_obj_list_element(obj, a) except AttributeError: return False return True @@ -29,7 +48,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False): attrs = attr.split('.') for a in attrs[:-1]: try: - obj = getattr(obj, a) + obj = get_obj_list_element(obj, a) except AttributeError: if ignore: return @@ -50,7 +69,7 @@ def getattr_(obj, attr: str, ignore: bool = False): attrs = attr.split('.') for a in attrs: try: - obj = getattr(obj, a) + obj = get_obj_list_element(obj, a) except AttributeError: if ignore: return None diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 6895113bc637..50121a9283f2 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -15,3 +15,4 @@ einops triton==2.0.0.dev20221202 git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 +SentencePiece diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py new file mode 100644 index 000000000000..ca44f0b00a74 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -0,0 +1,99 @@ +import copy +import os +import random + +import pytest +import torch +from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.shard import ShardConfig, shard_model +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) +tokenizer = T5Tokenizer.from_pretrained("t5-small") + + +def build_model(rank, world_size): + config = T5Config.from_pretrained("t5-small") + config.dropout_rate = 0 + org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda') + + shardconfig = ShardConfig( + rank=rank, + world_size=world_size, + gather_output=True, + ) + + org_model_for_shard = copy.deepcopy(org_model) + + sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda') + + return org_model, sharded_model + + +def check_forward(org_model, sharded_model): + + input_ids = tokenizer("translate English to German: The house is wonderful.", + return_tensors="pt").input_ids.to('cuda') + #orgin model + org_model.eval() + org_output = org_model.generate(input_ids) + + #shard model + sharded_model.eval() + shard_output = sharded_model.generate(input_ids) + assert torch.allclose( + org_output[0], shard_output[0], + atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" + + +def check_backward(org_model, sharded_model): + # prepare input + input_ids = tokenizer("translate English to German: The house is wonderful.", + return_tensors="pt").input_ids.to('cuda') + labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda') + + #orgin model + org_model.train() + org_loss = org_model(input_ids=input_ids, labels=labels).loss + org_loss.backward() + org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad + + #shard model + sharded_model.train() + shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss + shard_loss.backward() + shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +def check_t5(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + org_model, sharded_model = build_model(rank, world_size) + check_forward(org_model, sharded_model) + check_backward(org_model, sharded_model) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_t5(): + spawn(check_t5, 2) + + +if __name__ == "__main__": + test_t5() From 735e44b4bdfc065db024d5b731d92f32e1a3d18b Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 15 Jun 2023 17:56:51 +0800 Subject: [PATCH 14/49] [Shardformer] Downstream bert (#3979) * add dist dropout in model * update docstring and bert policy with dropout * refactor basepolicy and sharded, update bert * update format * update gpt2 policy * update bert policy * remove unused code * update readme for new policy usage * add downstream model of bert * remove unused code --- colossalai/shardformer/policies/autopolicy.py | 25 ++++ colossalai/shardformer/policies/bert.py | 112 +++++++++++++++--- colossalai/shardformer/shard/shard_config.py | 4 +- colossalai/shardformer/shard/sharder.py | 1 + .../test_model/test_shard_bert.py | 42 +++++-- 5 files changed, 151 insertions(+), 33 deletions(-) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index d4425497bd8e..e864719ac1ff 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -10,11 +10,31 @@ def build_policies(): """ auto_policy_dict = {} + from transformers import BertModel + + from .bert import BertModelPolicy + auto_policy_dict[BertModel] = BertModelPolicy + + from transformers import BertForPreTraining + + from .bert import BertForPretrainingPolicy + auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy + + from transformers import BertLMHeadModel + + from .bert import BertLMHeadModelPolicy + auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy + from transformers import BertForMaskedLM from .bert import BertForMaskedLMPolicy auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy + from transformers import BertForNextSentencePrediction + + from .bert import BertForNextSentencePredictionPolicy + auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy + from transformers import BertForSequenceClassification from .bert import BertForSequenceClassificationPolicy @@ -34,6 +54,11 @@ def build_policies(): from .llama import LlamaForCausalLMPolicy auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy + from transformers import BertForMultipleChoice + + from .bert import BertForMultipleChoicePolicy + auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy + from transformers import GPT2Model from .gpt2 import GPT2Policy diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 67e910d521e9..ba2266353e3e 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -35,12 +35,6 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: ]), } - @staticmethod - def binding_policy(): - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - @staticmethod def attn_in(): return [ @@ -148,9 +142,53 @@ def embedding(): replace_layer=col_nn.VocabParallelEmbedding1D, )] + @staticmethod + def unembedding(): + return [ + Col_Layer( + suffix="decoder", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ) + ] + + +# BertModel +class BertModelPolicy(BertPolicy): + + @staticmethod + def argument_policy(config, world_size): + return BertPolicy.argument_policy(config, world_size) + -from transformers import BertForMaskedLM +# BertForPretraining +class BertForPretrainingPolicy(BertPolicy): + @staticmethod + def argument_policy(config, world_size): + base_argument = BertPolicy.argument_policy(config, world_size) + argument = { + BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ + BertPolicy.unembedding, + ]), + } + argument.update(base_argument) + return argument + + @staticmethod + def inject_policy(): + return None + + @staticmethod + def binding_policy(): + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", + } + + +# BertForMaskedLM from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ @@ -161,7 +199,7 @@ def argument_policy(config, world_size): base_argument = BertPolicy.argument_policy(config, world_size) argument = { BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ - BertForMaskedLMPolicy.unembedding, + BertPolicy.unembedding, ]), } argument.update(base_argument) @@ -173,20 +211,56 @@ def inject_policy(): return None @staticmethod - def unembedding(): - return [ - Col_Layer( - suffix="decoder", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, - ) - ] + def binding_policy(): + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", + } -class BertForSequenceClassificationPolicy(BertPolicy): +# BertLMHeadModel +class BertLMHeadModelPolicy(BertPolicy): + + @staticmethod + def argument_policy(config, world_size): + base_argument = BertPolicy.argument_policy(config, world_size) + argument = { + BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ + BertPolicy.unembedding, + ]), + } + argument.update(base_argument) + return argument @staticmethod def inject_policy(): return None + + @staticmethod + def binding_policy(): + return { + "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", + } + + +# BertForNextSentencePrediction +class BertForNextSentencePredictionPolicy(BertPolicy): + + @staticmethod + def argument_policy(config, world_size): + return BertPolicy.argument_policy(config, world_size) + + +# BertForSequenceClassification +class BertForSequenceClassificationPolicy(BertPolicy): + + @staticmethod + def argument_policy(config, world_size): + return BertPolicy.argument_policy(config, world_size) + + +# BertForMultipleChoice +class BertForMultipleChoicePolicy(BertPolicy): + + @staticmethod + def argument_policy(config, world_size): + return BertPolicy.argument_policy(config, world_size) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index e8d6f3408c76..96c287577ddc 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,6 +13,6 @@ class ShardConfig: world_size (int): The world size of the distributed process gather_output (bool): Whether to gather the output of the model of the last layer """ - rank: int - world_size: int = 2 + rank: int = None + world_size: int = None gather_output: bool = True diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 8f6514cb4f5f..7ef0c37a4040 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -276,6 +276,7 @@ def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Poli shard_config (`ShardConfig`): the config for distribute information policy (`Policy`): the custom policy for sharding """ + # TODO: init shard_config automatically sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy) sharder.shard() return model diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 55b78d040505..9b29111eadb2 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -1,9 +1,19 @@ +import copy import os -import random import pytest import torch -from transformers import AutoTokenizer, BertConfig, BertForMaskedLM +from transformers import ( + AutoTokenizer, + BertConfig, + BertForMaskedLM, + BertForMultipleChoice, + BertForNextSentencePrediction, + BertForPreTraining, + BertForSequenceClassification, + BertLMHeadModel, + BertModel, +) import colossalai from colossalai.logging import disable_existing_loggers @@ -15,20 +25,21 @@ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") -def build_model(rank, world_size): +def build_model(rank, world_size, model): config = BertConfig.from_pretrained('bert-base-uncased') config.hidden_dropout_prob = 0 config.attention_probs_dropout_prob = 0 - org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda') + org_model = model(config=config) + org_model_forshard = copy.deepcopy(org_model) + org_model = org_model.to('cuda') shardconfig = ShardConfig( rank=rank, world_size=world_size, gather_output=True, ) - sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config), - shardconfig).to('cuda') + sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda') return org_model, sharded_model @@ -85,12 +96,19 @@ def check_backward(org_model, sharded_model): def check_bert(rank, world_size, port): disable_existing_loggers() colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - org_model, sharded_model = build_model(rank, world_size) - check_forward(org_model, sharded_model) - check_backward(org_model, sharded_model) - - torch.cuda.empty_cache() + forward_list = [ + BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction, + BertForSequenceClassification + ] + backward_lsit = [BertForMaskedLM, BertLMHeadModel] + + for model in forward_list: + org_model, sharded_model = build_model(rank, world_size, model) + check_forward(org_model, sharded_model) + if model in backward_lsit: + check_backward(org_model, sharded_model) + + torch.cuda.empty_cache() @pytest.mark.dist From 73cacb7c2a831cc3391d5a6011671f77059051fd Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 15 Jun 2023 10:14:08 +0800 Subject: [PATCH 15/49] [shardformer] fix an error in readme (#3988) * fix an error in readme * simplify code --- colossalai/tensor/d_tensor/RAEDME.md | 103 +++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 colossalai/tensor/d_tensor/RAEDME.md diff --git a/colossalai/tensor/d_tensor/RAEDME.md b/colossalai/tensor/d_tensor/RAEDME.md new file mode 100644 index 000000000000..3d862dddbf20 --- /dev/null +++ b/colossalai/tensor/d_tensor/RAEDME.md @@ -0,0 +1,103 @@ +# 🔢 Distributed Tensor + +## 📚 Table of Contents + +- [🔢 Distributed Tensor](#-distributed-tensor) + - [📚 Table of Contents](#-table-of-contents) + - [🔗 Introduction](#-introduction) + - [📝 Design](#-design) + - [🔨 Usage](#-usage) + - [🎈 Progress Log](#-progress-log) + +## 🔗 Introduction + +Distributed tensor is a type of tensor that is distributed across multiple devices. It is a wrapper of PyTorch tensor, and it is used to support distributed training. +It can represent the device topology and tensor placement over the devices in the topology. It also provides a set of APIs to manipulate the distributed tensor. + +## 📝 Design + +Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses notations `S` to represent the sharded dimension and `R` to represent the replicated dimension. For example, given a 2D matrix, `[S, R]` represents the tensor is sharded over the first dimension. + +Each sharded dimension will have a subscript to represent its placement over the devices. Assuming we have 4 GPUs and the GPUs are arranged in a 2 x 2 manner. Let's say we have a 2D matrix like below: + + +```text + [1, 2, 3, 4 ] +A = [4, 5, 6, 7 ] + [8, 9, 10, 11] + [12, 13, 14, 15] +``` + +`[S0, R]` would mean that the first dimension is sharded over the rows in the device topology. + +```text +| --------------------—————————————————————-| +| | | +| [1, 2, 3, 4 ] | [1, 2, 3, 4 ] | +| [4, 5, 6, 7 ] | [4, 5, 6, 7 ] | +| | | +| --------------------——————————————————----- +| | | +| [8, 9, 10, 11] | [8, 9, 10, 11] | +| [12, 13, 14, 15] | [12, 13, 14, 15] | +| | | +| --------------------——————————————————----- +``` + +`[S01, R]` would mean that the first dimension is sharded over both the row and column in the device topology. + +```text +| --------------------—————————————————————-| +| | | +| [1, 2, 3, 4 ] | [4, 5, 6, 7 ] | +| | | +| --------------------——————————————————----- +| | | +| [8, 9, 10, 11] | [12, 13, 14, 15] | +| | | +| --------------------——————————————————----- +``` + +## 🔨 Usage + +A sample API usage is given below. + +```python +import torch + +import colossalai +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.d_tensor import DTensor, ShardingSpec + +colossalai.launch_from_torch(config={}) + +# define your device mesh +# assume you have 4 GPUs +physical_mesh_id = torch.arange(0, 4) +mesh_shape = (2, 2) +device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + +# define a tensor +a = torch.rand(16, 32).cuda() + +# create sharding spec for the tensor +# assume the sharding spec is [S0, R] +dim_partition_dict = {0: [0]} +sharding_spec = ShardingSpec(a.dim(), dim_partition_dict) + +# create a distributed tensor +d_tensor = DTensor(a, device_mesh, sharding_spec) +print(d_tensor) + +global_tensor = d_tensor.to_global() +print(global_tensor) +``` + + +## 🎈 Progress Log + +- [x] Support layout conversion +- [x] Support sharding on 2D device mesh +- [ ] Support sharding on 3D device mesh +- [ ] Support sharding 4D device mesh +- [ ] Support sharding info saving and offline tensor merge (we can save tensor as dtensor and gather the tensors back to the global tensor based on the sharding info in a single process in CPU, useful for distributed training checkpoint load and save.) From 45a31104a1c88a1d375454fc81df3f2b46dbd2e5 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 15 Jun 2023 11:49:16 +0800 Subject: [PATCH 16/49] [device] support init device mesh from process group (#3990) --- colossalai/device/device_mesh.py | 553 +++++++++++++++++++------- tests/test_device/test_device_mesh.py | 69 ++++ 2 files changed, 474 insertions(+), 148 deletions(-) diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index 2a5f747fbc23..3e96310e1890 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -3,11 +3,19 @@ with some changes. """ import operator +from dataclasses import dataclass from functools import reduce -from typing import List, Tuple +from typing import Dict, List, Union import torch import torch.distributed as dist +from torch.distributed import ProcessGroup + + +@dataclass +class ProcessGroupContainer: + process_group: ProcessGroup + ranks: List[int] # modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py) @@ -27,9 +35,11 @@ class DeviceMesh: during initializing the DeviceMesh instance if the init_process_group set to True. Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group. (default: False) - need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True. + device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda') """ + _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"} + def __init__(self, physical_mesh_id: torch.Tensor, mesh_shape: torch.Size = None, @@ -37,160 +47,442 @@ def __init__(self, mesh_alpha: List[float] = None, mesh_beta: List[float] = None, init_process_group: bool = False, - need_flatten: bool = True): - self.physical_mesh_id = physical_mesh_id + device: str = 'cuda'): + # ============================ + # Physical & Logical Mesh IDs + # ============================ + self._physical_mesh_id = physical_mesh_id + assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor." + + # logical mesh ids can be obtained via two ways + # 1. provide physical mesh id and provide mesh shape + # 2. directly supply the logical mesh id + assert mesh_shape is None or logical_mesh_id is None, \ + "Only one of mesh_shape and logical_mesh_id can be specified." \ + "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id" + if logical_mesh_id is None: - self.mesh_shape = mesh_shape - self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape) + self._mesh_shape = mesh_shape + self._logical_mesh_id = self._physical_mesh_id.reshape(self._mesh_shape) else: self._logical_mesh_id = logical_mesh_id - self.mesh_shape = self._logical_mesh_id.shape + self._mesh_shape = self._logical_mesh_id.shape + + # ensure two things: + # 1. logical and physical mesh IDs should contain the same elements + # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed + assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \ + "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id." + assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \ + "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again." + assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \ + "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again." - # map global rank into logical rank - self.convert_map = {} - self._global_rank_to_logical_rank_map(self._logical_mesh_id, []) + # =============================================== # coefficient for alpha-beta communication model + # alpha is latency and beta is bandwidth + # =============================================== + # if the values are not provided, we assume they are 1 for simplicity if mesh_alpha is None: - mesh_alpha = [1] * len(self.mesh_shape) + mesh_alpha = [1] * len(self._mesh_shape) if mesh_beta is None: - mesh_beta = [1] * len(self.mesh_shape) + mesh_beta = [1] * len(self._mesh_shape) + self.mesh_alpha = tuple(mesh_alpha) self.mesh_beta = tuple(mesh_beta) - self.init_process_group = init_process_group - self.need_flatten = need_flatten - if self.init_process_group: - self.process_groups_dict = self.create_process_groups_for_logical_mesh() - if self.need_flatten and self._logical_mesh_id.dim() > 1: - self.flatten_device_mesh = self.flatten() - # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten()) - # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha, - # self.mesh_beta) + + # ensure the alpha and beta have the same shape + assert len(self.mesh_alpha) == len(self.mesh_beta), \ + "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again." + + # ========================= + # Device for Process Group + # ========================= + self._device = device + self._dist_backend = self._DIST_BACKEND[device] + + # ========================= + # Process Group Management + # ========================= + # the _global_to_local_rank_mapping is structured as follows + # { + # : [ , , , ...] + # } + self._global_to_local_rank_mapping = dict() + self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping, + tensor=self.logical_mesh_id) + + # create process group + self._process_group_dict = {} + self._ranks_in_the_process_group = {} + self._global_rank_of_current_process = None + self._is_initialized = False + + # attribute used to inidicate whether this objectd + # is created using DeviceMesh.from_process_group + # this attribute can be used to do some check in methods + # such get_process_group as no global rank information + # is known if created with from_process_group + self._is_init_from_process_group = False + + # initialize process group if specified + self._init_ranks_in_the_same_group() + self._init_process_group = init_process_group + if init_process_group: + self.init_logical_process_group() @property - def shape(self): - return self.mesh_shape + def shape(self) -> torch.Size: + """ + Return the shape of the logical mesh. + """ + return self._mesh_shape @property - def num_devices(self): - return reduce(operator.mul, self.physical_mesh_id.shape, 1) + def num_devices(self) -> int: + """ + Return the number of devices contained in the device mesh. + """ + return reduce(operator.mul, self._physical_mesh_id.shape, 1) @property - def logical_mesh_id(self): + def logical_mesh_id(self) -> torch.Tensor: + """ + Return the logical mesh id. + """ return self._logical_mesh_id - def __deepcopy__(self, memo): + @property + def is_initialized(self) -> bool: + """ + Return whether the process group is initialized. + """ + return self._is_initialized + + @staticmethod + def from_process_group(process_group: Union[ProcessGroup, List[ProcessGroup]]) -> "DeviceMesh": + """ + Create a DeviceMesh instance from the current process group. Please note that the DeviceMesh object created with this method + will not have information about the physical mesh id, and thus will not be able to query for other ranks and perform alpha-beta communication. + + Args: + process_group (Union[ProcessGroup, List[ProcessGroup]]): the process group or a list of process groups for the device mesh. + If the input is a ProcessGroup object, a 1D DeviceMesh object will be created. If the input is a list of ProcessGroup objects, + the ProcessGroup at the ith index will correspond to the process group in the ith axis of the device mesh. + + Returns: + DeviceMesh: the device mesh instance. + """ + + def _get_device_by_backend(process_group): + """ + Get the device type given a process group's backend. + """ + backend = dist.get_backend(process_group) + for _device, _backend in DeviceMesh._DIST_BACKEND.items(): + if _backend == backend: + return _device + return None + + if isinstance(process_group, ProcessGroup): + process_group = [process_group] + + # get mesh shape + mesh_shape = [dist.get_world_size(pg) for pg in process_group] + + # get device + device_list = [_get_device_by_backend(pg) for pg in process_group] + + # make sure all devices are the same + assert all([device == device_list[0] for device in device_list]), \ + "All devices should be the same, please check your input process groups are created with the same distributed backend." + + # create a fake physical mesh id + # as we only get the process group associated with the current process, + # we cannot get the global ranks for all processes in the mesh + # therefore, we only use this fake physical mesh id to create the device mesh + # and will remove this fake physical mesh id later + fake_physical_mesh_id = torch.arange(reduce(operator.mul, mesh_shape, 1)) + + # create the device mesh + device_mesh = DeviceMesh(physical_mesh_id=fake_physical_mesh_id, mesh_shape=mesh_shape, device=device_list[0]) + + # hack the device attribute + device_mesh._physical_mesh_id = None + device_mesh._logical_mesh_id = None + device_mesh._global_rank_of_current_process = dist.get_rank() + device_mesh._is_initialized = False + device_mesh._process_group_dict = { + device_mesh._global_rank_of_current_process: {axis: pg for axis, pg in enumerate(process_group)} + } + + return device_mesh + + def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup: + """ + Return the process group on the specified axis. + + Args: + axis (int): the axis of the process group. + global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None) + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + return self._process_group_dict[global_rank][axis] + + def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]: + """ + Return the process groups for all axes. + + Args: + global_rank (int, optional): the global rank of the process + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + return self._process_group_dict[global_rank] + + def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]: + """ + Return the ranks in the process group on the specified axis. + + Args: + axis (int): the axis of the process group. + global_rank (int, optional): the global rank of the process + """ + if global_rank is None: + global_rank = self._global_rank_of_current_process + elif self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + return self._ranks_in_the_process_group[global_rank][axis] + + def __deepcopy__(self, memo) -> "DeviceMesh": cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k != 'process_groups_dict': + if k != '_process_group_dict': setattr(result, k, __import__("copy").deepcopy(v, memo)) else: + # process group cannot be copied + # thus, we share them directly setattr(result, k, v) - return result - def flatten(self): - """ - Flatten the logical mesh into an effective 1d logical mesh, + def _init_global_to_logical_rank_mapping(self, + mapping: Dict, + tensor: torch.Tensor, + index_list: List[int] = []) -> Dict[int, List[int]]: """ - flatten_mesh_shape_size = len(self.mesh_shape) - flatten_mesh_shape = [self.num_devices] - return DeviceMesh(self.physical_mesh_id, - tuple(flatten_mesh_shape), - mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), - mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), - init_process_group=self.init_process_group, - need_flatten=False) + Build a global rank to local rank mapping for each process group in different axis in the logical device mesh. - def _global_rank_to_logical_rank_map(self, tensor, index_list): - ''' - This method is a helper function to build convert_map recursively. - ''' + Args: + mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. + tensor (torch.Tensor): the tensor that contains the logical mesh ids. + index_list (List[int]) + + Returns: + mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh. + The value is a list of integers and each integer represents the local rank in the indexed axis. + """ for index, inner_tensor in enumerate(tensor): + # index means the local rank in the current axis + # inner_tensor refers to the processes with the same local rank + if inner_tensor.numel() == 1: - self.convert_map[int(inner_tensor)] = index_list + [index] + # if the inner_tensor only has one element, it means that + # it already reaches the last axis + # we append its local_rank in the last axis to the index_list + # and assign to the mapping + # the value of the mapping is the the local rank at the indexed axis of the device mesh + mapping[int(inner_tensor)] = index_list + [index] else: - self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index]) + # we recursively go into the function until we reach the last axis + # meanwhile, we should add the local rank in the current axis in the index_list + self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index]) - def create_process_groups_for_logical_mesh(self): + def init_logical_process_group(self): ''' This method is used to initialize the logical process groups which will be used in communications among logical device mesh. Note: if init_process_group set to False, you have to call this method manually. Otherwise, the communication related function, such as ShapeConsistencyManager.apply will raise errors. ''' - process_groups_dict = {} - check_duplicate_list = [] - global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist() + # sanity check + assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group" + assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice" + + # update the global rank of the current process + self._global_rank_of_current_process = dist.get_rank() + duplicate_check_list = [] + + # flatten the global ranks to 1D list + global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() + for global_rank in global_rank_flatten_list: - process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank) - for axis, process_group in process_groups.items(): - if axis not in process_groups_dict: - process_groups_dict[axis] = [] - if process_group not in check_duplicate_list: - check_duplicate_list.append(process_group) - process_group_handler = dist.new_group(process_group) - process_groups_dict[axis].append((process_group, process_group_handler)) + # find the other ranks which are in the same process group as global_rank + ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) - return process_groups_dict + for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): + # skip duplicated process group creation + if ranks_in_same_group in duplicate_check_list: + continue - def global_rank_to_logical_rank(self, rank): - return self.convert_map[rank] + # create the process group + pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend) - def global_rank_to_process_groups_with_logical_rank(self, rank): - ''' - Give a global rank and return all logical process groups of this rank. - for example: - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) - mesh_shape = (4, 4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7], - # [8, 9, 10,11], - # [12,13,14,15]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - print(device_mesh.global_rank_to_process_groups_with_logical_rank(0)) - output: - # key is axis name - # value is a list of logical ranks in same axis with rank 0 - {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]} - ''' - process_groups = {} - for d in range(self.logical_mesh_id.dim()): - for replacer in range(self.logical_mesh_id.shape[d]): - if d not in process_groups: - process_groups[d] = [] - process_group_member = self.convert_map[rank].copy() - process_group_member[d] = replacer - process_groups[d].append(process_group_member) - return process_groups - - def global_rank_to_process_groups_with_global_rank(self, rank): + # keep this process group in the process_groups_dict + for rank in ranks_in_same_group: + if rank not in self._process_group_dict: + self._process_group_dict[rank] = dict() + self._process_group_dict[rank][axis] = pg_handler + + # update the init flag + # we only allow init for once + self._is_initialized = True + + def _init_ranks_in_the_same_group(self): + """ + This method is used to initialize the ranks_in_the_same_group dictionary. + """ + # flatten the global ranks to 1D list + global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist() + + for global_rank in global_rank_flatten_list: + # find the other ranks which are in the same process group as global_rank + ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank) + + for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items(): + # create dict for each rank + if global_rank not in self._process_group_dict: + self._ranks_in_the_process_group[global_rank] = dict() + + # keep this process group in the process_groups_dict + self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group + + def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]: + """ + Return the local rank of the given global rank in the logical device mesh. + + Args: + rank (int): the global rank in the logical device mesh. + axis (int): the axis of the logical device mesh. + """ + if self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + + local_ranks = self._global_to_local_rank_mapping[rank] + if axis: + return local_ranks[axis] + else: + return local_ranks + + def _collate_global_ranks_in_same_process_group(self, global_rank): ''' - Give a global rank and return all process groups of this rank. - for example: - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) - mesh_shape = (4, 4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7], - # [8, 9, 10,11], - # [12,13,14,15]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - print(device_mesh.global_rank_to_process_groups_with_global_rank(0)) - output: - # key is axis name - # value is a list of global ranks in same axis with rank 0 - {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]} + Give a global rank and return all global ranks involved in its associated process group in each axis. + + Example: + + ```python + sphysical_mesh_id = torch.arange(0, 16) + mesh_shape = (4, 4) + + # logical mesh will look like + # [[0, 1, 2, 3], + # [4, 5, 6, 7], + # [8, 9, 10,11], + # [12,13,14,15]] + + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + print(device_mesh.collate_global_ranks_in_same_process_group(0)) + + # key is axis name + # value is a list of global ranks in same axis with rank 0 + # output will look like + # { + 0: [0, 4, 8, 12], + 1: [0, 1, 2, 3] + # } ''' - logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank) - process_groups = {} - for dim, logical_ranks in logical_process_groups.items(): - process_groups[dim] = [] - for logical_rank in logical_ranks: - for g_rank, l_rank in self.convert_map.items(): - if l_rank == logical_rank: - process_groups[dim].append(g_rank) - return process_groups + # We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping + # for self._global_to_local_rank_mapping + # the key is the global rank + # the value is the list of local ranks corresponding to the global rank with respect of different axes + # we can see the list of local ranks as the process coordinates for simplicity + # the key and value are all unique, therefore, + # we can also to use the coordinates to find the global rank + + # ========================================================================= + # Step 1 + # find all the process_coordinates for processes in the same process group + # as the given global rank + # ========================================================================= + + # each + processes_in_the_same_process_group = {} + + for dim in range(self.logical_mesh_id.dim()): + # iterate over the dimension size so that we can include all processes + # in the same process group in the given axis + # the _local_rank refers to the local rank of the current process + for _local_rank in range(self.logical_mesh_id.shape[dim]): + + # if this dimension is not initailized yet, + # initialize it with an empty array + if dim not in processes_in_the_same_process_group: + processes_in_the_same_process_group[dim] = [] + + # get the local rank corresponding to the global rank + process_coordinates = self._global_to_local_rank_mapping[global_rank].copy() + + # replace the local rank in the given dimension with the + # lcoal rank of the current process iterated + process_coordinates[dim] = _local_rank + processes_in_the_same_process_group[dim].append(process_coordinates) + + # ================================================================= + # Step 2 + # Use local rank combination to find its corresponding global rank + # ================================================================= + # the key of the dict is the axis + # the value is the list of global ranks which are in the same process group as the given global rank + global_pg_ranks = {} + for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items(): + global_pg_ranks[dim] = [] + for process_coordinates in coordinates_of_all_processes: + # find the global rank by local rank combination + for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items(): + if process_coordinates == _process_coordinates: + global_pg_ranks[dim].append(_global_rank) + return global_pg_ranks + + def flatten(self): + """ + Flatten the logical mesh into an effective 1d logical mesh, + """ + if self._is_init_from_process_group: + raise RuntimeError( + "The logical device mesh is create with DeviceMesh.from_process_group, this method is not supported for this creation method as no global rank information is known." + ) + + flatten_mesh_shape_size = len(self._mesh_shape) + flatten_mesh_shape = [self.num_devices] + return DeviceMesh(self._physical_mesh_id, + tuple(flatten_mesh_shape), + mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1), + mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1), + init_process_group=self._init_process_group) def all_gather_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] @@ -211,39 +503,4 @@ def all_to_all_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] penalty_factor = num_devices / 2.0 return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * - (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) - - -class FlattenDeviceMesh(DeviceMesh): - - def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None): - super().__init__(physical_mesh_id, - mesh_shape, - mesh_alpha, - mesh_beta, - init_process_group=False, - need_flatten=False) - # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars - self.mesh_alpha = max(self.mesh_alpha) - self.mesh_beta = min(self.mesh_beta) - # Different from original process_groups_dict, rank_list is not stored - self.process_number_dict = self.create_process_numbers_for_logical_mesh() - - def create_process_numbers_for_logical_mesh(self): - ''' - Build 1d DeviceMesh in column-major(0) and row-major(1) - for example: - mesh_shape = (2,4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7]] - # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} - ''' - num_devices = reduce(operator.mul, self.mesh_shape, 1) - process_numbers_dict = {} - process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist() - process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist() - return process_numbers_dict - - def mix_gather_cost(self, num_bytes): - num_devices = reduce(operator.mul, self.mesh_shape, 1) - return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1) + (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) \ No newline at end of file diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 789ce8ab35b8..e9f0f9477e4a 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -1,6 +1,10 @@ +import pytest import torch +import torch.distributed as dist +import colossalai from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import rerun_if_address_is_in_use, spawn def test_device_mesh(): @@ -18,5 +22,70 @@ def test_device_mesh(): assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] +def check_1d_device_mesh(): + # check for 1D device mesh + process_group = dist.GroupMember.WORLD + device_mesh = DeviceMesh.from_process_group(process_group) + + # checks + assert device_mesh.shape == [4] + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 1, 'Expected 1 axis for the process group dict' + assert device_mesh.get_process_group(axis=0) == process_group, 'Expected world process group' + assert device_mesh.is_initialized + assert device_mesh.num_devices == 4 + assert device_mesh.is_initialized + assert device_mesh.logical_mesh_id is None + assert device_mesh._is_init_from_process_group + + +def check_2d_device_mesh(): + # create process group for 2D device mesh + first_row_ranks = [0, 1] + second_row_ranks = [2, 3] + first_col_ranks = [0, 2] + second_col_ranks = [1, 3] + + first_row_pg = dist.new_group(first_row_ranks, backend='nccl') + second_row_pg = dist.new_group(second_row_ranks, backend='nccl') + first_col_pg = dist.new_group(first_col_ranks, backend='nccl') + second_col_pg = dist.new_group(second_col_ranks, backend='nccl') + + # check for + current_rank = dist.get_rank() + + if current_rank in first_row_ranks: + row_pg = first_row_pg + else: + row_pg = second_row_pg + + if current_rank in first_col_ranks: + col_pg = first_col_pg + else: + col_pg = second_col_pg + + device_mesh = DeviceMesh.from_process_group([col_pg, row_pg]) + + # checks + assert device_mesh.shape == [2, 2] + assert len(device_mesh.get_process_group_for_all_axes().keys()) == 2, 'Expected 2 axes for the process group dict' + assert device_mesh.get_process_group(axis=0) == col_pg, 'Expected column process group' + assert device_mesh.get_process_group(axis=1) == row_pg, 'Expected row process group' + assert device_mesh.num_devices == 4 + assert device_mesh.is_initialized + assert device_mesh.logical_mesh_id is None + assert device_mesh._is_init_from_process_group + + +def check_init_from_process_group(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_device_mesh_from_process_group(): + spawn(check_init_from_process_group, 4) + + if __name__ == '__main__': test_device_mesh() + test_device_mesh_from_process_group() From 18396e7666fd4cc9be144a6b90aed82764a8c94b Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 15 Jun 2023 17:55:42 +0800 Subject: [PATCH 17/49] [shardformer] Refactor shardformer api (#4001) * fix an error in readme * simplify code * refactor shardformer * add todo * remove slicer * resolve code review --- colossalai/shardformer/__init__.py | 2 +- colossalai/shardformer/policies/autopolicy.py | 56 ++-- colossalai/shardformer/policies/basepolicy.py | 257 +++++----------- colossalai/shardformer/policies/bert.py | 282 ++++------------- colossalai/shardformer/shard/__init__.py | 6 +- colossalai/shardformer/shard/shard_config.py | 17 +- colossalai/shardformer/shard/sharder.py | 291 ++++++------------ colossalai/shardformer/shard/shardformer.py | 77 +++++ colossalai/shardformer/shard/slicer.py | 163 ---------- colossalai/shardformer/utils/__init__.py | 1 + 10 files changed, 342 insertions(+), 810 deletions(-) create mode 100644 colossalai/shardformer/shard/shardformer.py delete mode 100644 colossalai/shardformer/shard/slicer.py diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index 50c92738077a..77c2af8d18f7 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1 +1 @@ -from .shard import ShardConfig, shard_model +from .shard import ShardConfig, ShardFormer diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index e864719ac1ff..6239397b7cbe 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -1,5 +1,7 @@ import torch.nn as nn +from .basepolicy import Policy + def build_policies(): r""" @@ -41,47 +43,25 @@ def build_policies(): auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy from transformers.models.llama.modeling_llama import LlamaModel - from .llama import LlamaPolicy - auto_policy_dict[LlamaModel] = LlamaPolicy - - from transformers import LlamaForSequenceClassification - - from .llama import LlamaForSequenceClassificationPolicy - auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy - - from transformers import LlamaForCausalLM - - from .llama import LlamaForCausalLMPolicy - auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy - - from transformers import BertForMultipleChoice - - from .bert import BertForMultipleChoicePolicy - auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy - - from transformers import GPT2Model - - from .gpt2 import GPT2Policy - auto_policy_dict[GPT2Model] = GPT2Policy - - from transformers import GPT2LMHeadModel - - from .gpt2 import GPT2LMHeadModelPolicy - auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy - - from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy - from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model - t5 = { - T5ForConditionalGeneration: T5ForConditionalGenerationPolicy, - T5EncoderModel: T5EncoderModelPolicy, - T5Model: T5ModelPolicy, - } - auto_policy_dict.update(t5) + # from .llama import LlamaPolicy + # auto_policy_dict[LlamaModel] = LlamaPolicy + # from transformers import LlamaForSequenceClassification + # from .llama import LlamaForSequenceClassificationPolicy + # auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy + # from transformers import LlamaForCausalLM + # from .llama import LlamaForCausalLMPolicy + # auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy + # from transformers import GPT2Model + # from .gpt2 import GPT2Policy + # auto_policy_dict[GPT2Model] = GPT2Policy + # from transformers import GPT2LMHeadModel + # from .gpt2 import GPT2LMHeadModelPolicy + # auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy return auto_policy_dict -def get_autopolicy(model: nn.Module): +def get_autopolicy(model: nn.Module) -> Policy: r""" Return the auto policy for the model @@ -97,7 +77,7 @@ def get_autopolicy(model: nn.Module): raise NotImplementedError( f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}" ) - return policy + return policy() # from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index ba3a97f1bbcd..80ea7a252131 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -1,102 +1,65 @@ # part of code modified from https://github.com/tunib-ai/parallelformers +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Type, Union import torch.nn as nn +from ..shard.shard_config import ShardConfig -@dataclass -class Argument: - r""" - The argument class for the policy - Args: - attr_dict (Dict[str, Any]): The dict for the param setting - param_funcs (:class:`List[Callable]`): The list for the param functions - """ - attr_dict: Dict[str, Any] - param_funcs: List[Callable] +class ParallelModule(): - -@dataclass -class Layer: - r""" - The layer object for the policy - - Args: - suffix: (str): the suffix of the layer. - replace_layer (:class:`colosalai.nn`): The layer to replace the original layer - ignore (bool): Whether to ignore this layer if it is not in the model - reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], - but in GPT2 `Conv1D` layer is [in, out] which is reversed. - n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices, - but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and - each device should have a part of Q, K and V weight. - """ - suffix: str = None - replace_layer: Any = None - ignore: bool = False - reversed: bool = False - n_cast: int = None + def __init__(self): + pass @dataclass -class Col_Layer(Layer): +class SubModuleReplacementDescription: r""" - Class for col shard layer in tensor parrallel + Describe how a submodule will be replaced - Args: - weight (str): The weight suffix of the layer - bias (str): The bias suffix of the layer - gather_output (bool): Whether to gather the output of the layer + suffix (str): used to get the submodule object + target_module (ParallelModule): specifies the module class used to replace to submodule + kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. """ - weight: str = None - bias: str = None - gather_output: bool = False + suffix: str + target_module: ParallelModule + kwargs: Dict[str, Any] = None @dataclass -class Row_Layer(Layer): +class ModulePolicyDescription: r""" - Class for col shard layer in tensor parrallel - - Args: - weight (str): The weight suffix of the layer - bias (str): The bias suffix of the layer - """ - weight: str = None - bias: str = None + Describe how the attributes and parameters will be transformed in a policy + attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding + param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function + must receive two arguments: module, process_group. One example is -@dataclass -class Dropout_Layer(Layer): - r""" - Class for dropout layer in tensor parrallel - - Args: - p (str): The dropout rate suffix of the layer - """ - p: str = None - - -@dataclass -class Embedding_Layer(Layer): - r""" - Class for col shard layer in tensor parrallel + ```python + def example_replace_weight(module: torch.nn.Module, process_group): + weight = module.weight + new_weight = shard_rowwise(weight, process_group) + module.weight = torch.nn.Parameter(new_weight) + ``` - Args: - weight (str): The weight suffix of the layer + sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies + the module to be replaced and the target module used to replacement """ - weight: str = None - gather_output: bool = True + attribute_replacement: Dict[str, Any] + param_replacement: List[Callable] + sub_module_replacement: List[SubModuleReplacementDescription] -class Policy(): +class Policy(ABC): r""" The base class for all the policies + For each different model, it should have a different policy class, like BertPolicy for Bert Model or OPTPolicy for OPT model. + AutoPolicy: Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None to use the auto policy. In shardformer autopolicy, we define a base policy for one type model, @@ -111,137 +74,75 @@ class for the example. """ - @staticmethod - def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]: + def __init__(self) -> None: + self.model = None + + def set_model(self, model: nn.Module) -> None: r""" - Return the dict for the modify policy, the key is the original layer class and the value is the - argument for the modify layer + Set model as an attribute of the Policy object so that we can access the model's attributes. Args: - model_config (:class:`tansformer.Config`): The config of transformer model - world_size (int)): The world size of sharding model + model (:class:`nn.Module`): The model to be perform + """ + self.model = model + + @abstractmethod + def preprocess(self, shard_config: ShardConfig = None) -> nn.Module: + r""" + Perform some preprocessing of the model, like reshaping the embedding layer + """ + + @abstractmethod + def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + r""" + Return the dict for the modify policy, the key is the original layer class and the value is the + argument for the modify layer Return: Dict for the modify policy, :: { - origin layer class1 (nn.Module): Argument( - attr_dict = { - argument1: value1, - argument2: value2, + origin layer class1 (nn.Module): ModulePolicyDescription( + attribute_replacement = { + "attribute1": value1, + "attribute2": value2, ... }, - param_funcs = [ - staticmethod1, - staticmethod2, + param_replacement = [ + function1, + function2, ... - ] - ), - origin layer class2 (nn.Module): Argument( - attr_dict = { - argument1: value1, - argument2: value2, - ... - }, - param_funcs = [ - staticmethod1, - staticmethod2, + ], + sub_module_replacement = [ + `SubModuleReplacementDescription` description1, + `SubModuleReplacementDescription` description2, ... ] ), + origin layer class2 (nn.Module): ModulePolicyDescription( + ... + ), ... } - - """ - raise NotImplementedError - - @staticmethod - def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]: - r""" - Return the dict for the inject model - - Return: - The injected model, key is the original model and value is the new shardmodel - :: - (OrignModel, CustomModel) - in `CustomModel`, we can overwrite the forward and backward process """ - return None - @staticmethod - def binding_policy() -> Union[Dict[str, str], None]: + @abstractmethod + def new_model_class(self) -> Union[Type[nn.Module], None]: r""" - Return the dict for the binding model, None means no need to bind + Return the new model class for the new model, None means no need to modify the model class Return: - This method should return the binding relationship for some layers share the weight or bias, - the key and value is the suffix of the weight or bias of the model - :: - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - """ - return None - - @staticmethod - def attn_in() -> Union[List, None]: - r""" - Attention qkv layer - In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be - ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters - in ``Layer`` object can refer to the ``Layer`` class. - - Returns: - List[Layer]: List of layer object, each layer is the new - """ - return None - - @staticmethod - def attn_out() -> Union[List, None]: - r""" - Attention output projection layer - - Returns: - List[Layer]: List of layer object - """ - return None - - @staticmethod - def mlp_in() -> Union[List, None]: - r""" - h -> 4h mlp layer - - Returns: - List[Layer]: List of layer object - """ - return None - - @staticmethod - def mlp_out() -> Union[List, None]: - r""" - 4h -> h mlp layer - - Returns: - List[Layer]: List of layer object - """ - return None - - @staticmethod - def embedding() -> Union[List, None]: - r""" - Partially slice the embedding layer + New model class - Return: - List[Layer]: List of layer object + E.g. + ``` + return BertModel_ + ``` """ - return None - @staticmethod - def unembedding() -> Union[List, None]: + @abstractmethod + def postprocess(self) -> nn.Module: r""" - Partially slice the embedding layer, None means there is no unembedding layer - - Return: - List[Layer]: List of layer object + Perform some postprocessing of the model, like binding the weight of embedding layer with + the classifier layer """ - return None diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index ba2266353e3e..f3431c386fe4 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,220 +1,77 @@ -from typing import Any, Callable, Dict, List, Tuple, Type - import torch.nn as nn from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead import colossalai.shardformer.layer.layers as col_nn -from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer +from ..shard.shard_config import ShardConfig +from ..utils import getattr_, setattr_ +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + +class ParallelModule(): + + def __init__(self): + pass class BertPolicy(Policy): - @staticmethod - def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: + def preprocess(self, shard_config: ShardConfig = None): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + # TODO: + vocab_size = self.model.config.vocab_size + world_size = shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self, shard_config: ShardConfig = None): return { BertLayer: - Argument( - attr_dict={ + ModulePolicyDescription( + attribute_replacement={ # 1. shard hidden size - "attention.self.all_head_size": config.hidden_size // world_size, - "crossattention.self.all_head_size": config.hidden_size // world_size, + "attention.self.all_head_size": + self.model.config.hidden_size // shard_config.tensor_parallel_size, + "crossattention.self.all_head_size": + self.model.config.hidden_size // shard_config.tensor_parallel_size, # 2. shard number of heads - "attention.self.num_attention_heads": config.num_attention_heads // world_size, - "crossattention.self.num_attention_heads": config.num_attention_heads // world_size, + "attention.self.num_attention_heads": + self.model.config.num_attention_heads // shard_config.tensor_parallel_size, + "crossattention.self.num_attention_heads": + self.model.config.num_attention_heads // shard_config.tensor_parallel_size, }, - param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]), - BertEmbeddings: - Argument( - attr_dict={ - # 1. shard vocab size - "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size, - }, - param_funcs=[ - BertPolicy.embedding, - ]), - } - - @staticmethod - def attn_in(): - return [ - Col_Layer( - suffix="attention.self.query", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - suffix="attention.self.key", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - suffix="attention.self.value", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Dropout_Layer( - suffix="attention.self.dropout", - p="p", - replace_layer=col_nn.Dropout1D, - ), - Col_Layer( - suffix="crossattention.self.query", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ignore=True, - ), - Col_Layer( - suffix="crossattention.self.key", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ignore=True, - ), - Col_Layer( - suffix="crossattention.self.value", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ignore=True, - ), - ] - - @staticmethod - def attn_out(): - return [ - Row_Layer( - suffix="attention.output.dense", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Row, - ), - Dropout_Layer( - suffix="attention.output.dropout", - p="p", - replace_layer=col_nn.Dropout1D, - ), - Row_Layer( - suffix="crossattention.output.dense", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Row, - ignore=True, - ), - ] - - @staticmethod - def mlp_in(): - return [ - Col_Layer( - suffix="intermediate.dense", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - ] - - @staticmethod - def mlp_out(): - return [ - Row_Layer( - suffix="output.dense", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Row, - ), - Dropout_Layer( - suffix="output.dropout", - p="p", - replace_layer=col_nn.Dropout1D, - ) - ] - - @staticmethod - def embedding(): - return [Col_Layer( - suffix="word_embeddings", - weight="weight", - replace_layer=col_nn.VocabParallelEmbedding1D, - )] - - @staticmethod - def unembedding(): - return [ - Col_Layer( - suffix="decoder", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, - ) - ] - - -# BertModel -class BertModelPolicy(BertPolicy): - - @staticmethod - def argument_policy(config, world_size): - return BertPolicy.argument_policy(config, world_size) - - -# BertForPretraining -class BertForPretrainingPolicy(BertPolicy): - - @staticmethod - def argument_policy(config, world_size): - base_argument = BertPolicy.argument_policy(config, world_size) - argument = { - BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ - BertPolicy.unembedding, - ]), + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=ParallelModule, + ), + ]) } - argument.update(base_argument) - return argument - @staticmethod - def inject_policy(): + def new_model_class(self): + # do nothing return None - @staticmethod - def binding_policy(): - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - - -# BertForMaskedLM -from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model class BertForMaskedLMPolicy(BertPolicy): - @staticmethod - def argument_policy(config, world_size): - base_argument = BertPolicy.argument_policy(config, world_size) - argument = { - BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ - BertPolicy.unembedding, - ]), - } - argument.update(base_argument) - return argument - - @staticmethod - def inject_policy(): - # return (BertForMaskedLM, BertForMaskedLM_) - return None - - @staticmethod - def binding_policy(): - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } + def __init__(self) -> None: + super().__init__() # BertLMHeadModel @@ -231,36 +88,5 @@ def argument_policy(config, world_size): argument.update(base_argument) return argument - @staticmethod - def inject_policy(): - return None - - @staticmethod - def binding_policy(): - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - - -# BertForNextSentencePrediction -class BertForNextSentencePredictionPolicy(BertPolicy): - - @staticmethod - def argument_policy(config, world_size): - return BertPolicy.argument_policy(config, world_size) - - -# BertForSequenceClassification -class BertForSequenceClassificationPolicy(BertPolicy): - - @staticmethod - def argument_policy(config, world_size): - return BertPolicy.argument_policy(config, world_size) - - -# BertForMultipleChoice -class BertForMultipleChoicePolicy(BertPolicy): - - @staticmethod - def argument_policy(config, world_size): - return BertPolicy.argument_policy(config, world_size) + def __init__(self) -> None: + super().__init__() diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index d5f70163ad57..7abdd45ec7c5 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -1,5 +1,5 @@ from .shard_config import ShardConfig -from .sharder import ModelSharder, shard_model -from .slicer import Slicer +from .sharder import ModelSharder +from .shardformer import ShardFormer -__all__ = ['ShardConfig', 'ModelSharder', 'shard_model', 'Slicer'] +__all__ = ['ShardConfig', 'ModelSharder', 'ShardFormer'] diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 96c287577ddc..53999529d277 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import List, Literal __all__ = ['ShardConfig'] @@ -9,10 +10,18 @@ class ShardConfig: The config for sharding the huggingface model Args: - rank (int): The rank of local process - world_size (int): The world size of the distributed process + data_parallel_size (int): The size of data parallel + tensor_parallel_size (int): The size of tensor parallel + pipeline_parallel_size (int): The size of pipeline parallel + tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d'] + inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model + will not calculate the loss and just return the output. gather_output (bool): Whether to gather the output of the model of the last layer """ - rank: int = None - world_size: int = None + data_parallel_size: int + tensor_parallel_size: int + + pipeline_parallel_size: int + tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] + inference_only: bool = True gather_output: bool = True diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 7ef0c37a4040..8eee3c6a3b7e 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -4,11 +4,12 @@ import torch.nn as nn from transformers.pytorch_utils import Conv1D +from colossalai.cluster.process_group_manager import ProcessGroupManager + from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer -from ..utils.utils import getattr_, hasattr_, setattr_ +from ..policies.basepolicy import Policy +from ..utils.utils import setattr_ from .shard_config import ShardConfig -from .slicer import Slicer __all__ = ['ModelSharder', 'shard_model'] @@ -28,20 +29,23 @@ def __init__( model: nn.Module, policy: Policy, shard_config: ShardConfig = None, # TODO - ) -> None: + pg_manager: ProcessGroupManager = None) -> None: self.model = model self.policy = get_autopolicy(self.model) if policy is None else policy - self.slicer = Slicer(shard_config) self.shard_config = shard_config - self.model_config = self.model.config + self.pg_manager = pg_manager def shard(self) -> None: - self.reshape_embedding() - self.inject_model(self.model) - self.replace_layer(self.model) - self.bind_layer(self.model) + r""" + Shard the model according to the policy + """ + self.policy.set_model(self.model) + self.preprocess() + self.replace_model_class() + self.replace_module() + self.postprocess() - def reshape_embedding(self,) -> None: + def reshape_embedding(self) -> None: r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ @@ -52,10 +56,13 @@ def reshape_embedding(self,) -> None: self.model.resize_token_embeddings(new_vocab_size) self.model_config = self.model.config - def inject_model( - self, - model: nn.Module, - ) -> None: + def preprocess(self) -> None: + self.model = self.policy.preprocess(self.shard_config) + + def postprocess(self) -> None: + self.model = self.policy.postprocess() + + def replace_model_class(self,) -> None: r""" Replace the model to policy defined model Mainly modify the forward and backward to fit distributed model @@ -64,49 +71,43 @@ def inject_model( :: BertForMaskedLM.forward -> BertForMaskedLM_.forward """ - inject_policy = self.policy.inject_policy() - if inject_policy is None: - return - - if inject_policy is None: + new_model_class = self.policy.new_model_class() + if new_model_class is None: return - org_model_cls = inject_policy[0] - shard_model_cls = inject_policy[1] - if model.__class__ == org_model_cls: - for key in shard_model_cls.__dict__.keys(): - if hasattr(model.__class__, key): - setattr( - model.__class__, - key, - getattr(shard_model_cls, key), - ) - else: - raise NotImplementedError(f"{model.__class__} is not implemented so far") + for key in new_model_class.__dict__.keys(): + if hasattr(self.model.__class__, key): + setattr( + self.model.__class__, + key, + getattr(new_model_class, key), + ) - def replace_layer( - self, - model: nn.Module, - ) -> None: + def replace_module(self,) -> None: r""" - Replace the layer according to the policy, and replace the layer one by one + Replace the module according to the policy, and replace the module one by one Args: - model (:class:`torch.nn.Module`): The layer to shard + model (:class:`torch.nn.Module`): The model to shard """ - argument_policies = self.policy.argument_policy(self.model_config, self.shard_config.world_size) - for argument_policy in argument_policies.items(): - origin_layer_cls = argument_policy[0] - attr_dict = argument_policy[1].attr_dict - param_funcs = argument_policy[1].param_funcs - self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs) - - def traverse_replace_layer( + print(self.policy) + module_descriptions = self.policy.module_policy(self.shard_config) + print(f"*******{module_descriptions}") + for module_description in module_descriptions.items(): + origin_layer_cls = module_description[0] + attr_replacement = module_description[1].attribute_replacement + param_replacement = module_description[1].param_replacement + sub_module_replacement = module_description[1].sub_module_replacement + self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement, + sub_module_replacement) + + def _recursive_replace_layer( self, - layer: nn.Module, + module: nn.Module, origin_cls: nn.Module, - attr_dict: Dict[str, Any], - param_funcs: List[Callable], + attr_replacement: Dict[str, Any], + param_replacement: List[Callable], + sub_module_replacement: List[Callable], ) -> None: r""" Reverse the replace layer operation @@ -114,169 +115,69 @@ def traverse_replace_layer( Args: layer (:class:`torch.nn.Module`): The object of layer to shard origin_cls (:class:`transformers.model`): The origin layer class - attr_dict (Dict): The attribute dict to modify - policy_cls (:class:`Policy`): The policy class + attr_replacement (Dict): The attribute dict to modify + param_replacement (List[Callable]): The function list to get parameter shard information in polic + sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy """ - if layer.__class__ == origin_cls: - for k, v in attr_dict.items(): - setattr_(layer, k, v, ignore=True) - self.shard_one_layer(layer, param_funcs) - for name, child in layer.named_children(): - self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs) - return layer - - def shard_one_layer( + if module.__class__ == origin_cls: + self._replace_attr(module, attr_replacement) + self._replace_param(module, param_replacement) + self._replace_sub_module(module, sub_module_replacement) + for name, child in module.named_children(): + self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, + sub_module_replacement) + + def _replace_attr( self, - org_layer: nn.Module, - param_funcs: List[Callable], + module: nn.Module, + attr_replacement: Dict[str, Any], ) -> None: r""" - Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict + Replace the attribute of the layer Args: - org_layer (:class:`torch.nn.Module`): The origin layer object to shard - param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class - + layer (:class:`torch.nn.Module`): The object of layer to shard + attr_replacement (Dict): The attribute dict to modify """ - for func in param_funcs: - policy_layers = func() - for policy_layer in policy_layers: - suffix = policy_layer.suffix - replace_layer_cls = policy_layer.replace_layer - ignore = policy_layer.ignore - reversed = policy_layer.reversed - n_cast = policy_layer.n_cast - - assert replace_layer_cls is not None, 'replace_layer should not be None' - - # create new object to replace the origin layer - # Linear - suffix_layer = getattr_(org_layer, suffix, ignore=True) - assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}" - if suffix_layer is None and ignore: - continue - if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)): - weight = None - bias = None - weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None - bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None - - if weight_attr is not None: - if hasattr_(org_layer, weight_attr): - weight = getattr_(org_layer, weight_attr) - else: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") + for k, v in attr_replacement.items(): + setattr_(module, k, v, ignore=True) - if bias_attr is not None: - if hasattr_(org_layer, bias_attr): - bias = getattr_(org_layer, bias_attr) - else: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") - - # set the sliced weight and bias to the new nn_col layer - assert weight is not None or bias is not None - - # slice weight and bias - weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed) - - if replace_layer_cls.__name__ == "Linear1D_Row": - replace_layer = replace_layer_cls(weight.shape[1], - weight.shape[0], - bias=False if bias is None else True) - elif replace_layer_cls.__name__ == "Linear1D_Col": - gather_output = policy_layer.gather_output and self.shard_config.gather_output - replace_layer = replace_layer_cls(weight.shape[0], - weight.shape[1], - bias=False if bias is None else True, - gather_output=gather_output) - elif replace_layer_cls.__name__ == "Embedding1D": - gather_output = policy_layer.gather_output - replace_layer = replace_layer_cls(weight.shape[0], - weight.shape[1], - gather_output=gather_output) - elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D": - replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], - getattr_(org_layer, f"{suffix}.padding_idx", ignore=True)) - # setattr_(org_layer, suffix, replace_layer, ignore=ignore) - # self.set_param(replace_layer, weight, bias) - else: - raise NotImplementedError( - f"Replacing to {replace_layer_cls.__name__} is not implemented so far") - setattr_(org_layer, suffix, replace_layer, ignore=ignore) - self.set_param(replace_layer, weight, bias) - # dropout - elif isinstance(policy_layer, Dropout_Layer): - p_attr = suffix + '.' + policy_layer.p - p = getattr_(org_layer, p_attr, ignore=True) - replace_layer = replace_layer_cls(p) - setattr_(org_layer, suffix, replace_layer, ignore=ignore) - else: - raise NotImplementedError( - f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far") - - def set_param(self, - layer: Any, - weight: torch.Tensor = None, - bias: torch.Tensor = None, - layer_attr: str = "") -> None: + def _replace_param( + self, + module: nn.Module, + param_replacement: List[Callable], + ) -> None: r""" - Reset the weight and bias of the layer object + Replace the parameter of the layer Args: - layer (:class:`torch.nn.Module`): The layer object - layer_attr (str): The attribute name of the layer - weight (:class:`torch.Tensor`): The weight of the layer - bias (:class:`torch.Tensor`): The bias of the layer + layer (:class:`torch.nn.Module`): The object of layer to shard + param_replacement (List[Callable]): The function list to get parameter shard information in policy """ - assert weight is not None or bias is not None - if weight is not None: - setattr_(layer, "weight" if layer_attr == "" else layer_attr + ".weight", nn.Parameter(weight.contiguous())) - self.set_layer_size(layer, layer_attr, weight.shape) - if bias is not None: - setattr_(layer, "bias" if layer_attr == "" else layer_attr + ".bias", nn.Parameter(bias.contiguous())) + # TODO: support parameter shard + pass - def set_layer_size(self, layer: nn.Module, layer_attr: str, size: torch.Size) -> None: + def _replace_sub_module( + self, + org_layer: nn.Module, + sub_module_replacement: List[Callable], + ) -> None: r""" - Set the layer attribute + Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict Args: - layer (:class:`torch.nn.Module`): The layer object - layer_attr (str): The attribute name of the layer - size (:class:`torch.Size`): The size of the tensor - """ - # Tensor.shape[0] -> out_features, Tensor.shape[1] -> in_features - attrs = ["out_features", "in_features"] - for i, attr in enumerate(attrs): - if hasattr_(layer, f"{layer_attr}.{attr}"): - setattr_(layer, f"{layer_attr}.{attr}", size[i]) - - def bind_layer(self, model: nn.Module) -> None: - r""" - Bind the layer according to the binding policy + org_layer (:class:`torch.nn.Module`): The origin layer object to shard + param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class - Args: - model (:class:`torch.nn.Module`): The shard model """ - binding_map = self.policy.binding_policy() - if binding_map is None: - return - for k, v in binding_map.items(): - param = getattr_(model, k) - param = nn.Parameter(param) - setattr_(model, k, param) - setattr_(model, v, param) - + for description in sub_module_replacement: + suffix = description.suffix + target_module = description.target_module + kwargs = description.kwargs -def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Policy = None): - r""" - The function is used to shard the PyTorch model. + assert target_module is not None, 'target_module should not be None' - Args: - model (`torch.nn.Model`): the origin huggingface model - shard_config (`ShardConfig`): the config for distribute information - policy (`Policy`): the custom policy for sharding - """ - # TODO: init shard_config automatically - sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy) - sharder.shard() - return model + # TODO: integrate with new layer + # replace_layer = target_module.from_native_layer(org_layer, self.pg_manager) + replace_layer = None + setattr_(org_layer, suffix, replace_layer) diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py new file mode 100644 index 000000000000..5313dfecb37e --- /dev/null +++ b/colossalai/shardformer/shard/shardformer.py @@ -0,0 +1,77 @@ +import torch.nn as nn +from torch.utils.data import Dataset + +from colossalai.cluster import DistCoordinator, ProcessGroupManager + +from ..policies.basepolicy import Policy +from .shard_config import ShardConfig +from .sharder import ModelSharder + + +class ShardFormer: + """ + Parallelize model based on the given config and policy + + Example: + + ```python + from colossalai.shardformer import ShardFormer, ShardConfig + from transformers import BertForMaskedLM + import colossalai + import torch + + colossalai.launch_from_torch(config={}) + + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') + shard_config = ShardConfig( + tensor_parallel_size=2, + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_mode='1d', + inference_only=True, + gather_output=True + ) + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + model = shard_former.shard_model(org_model) + ``` + """ + + def __init__(self, shard_config: ShardConfig): + """ + Do two things: + 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp + 2. serve as a store for + """ + self.coordinator = DistCoordinator() + self.shard_config = shard_config + self.pg_manager = None + + def init_distributed(self) -> ProcessGroupManager: + """ + Initialize the distributed process group according to the + """ + pg_manager = ProcessGroupManager() + if (self.shard_config.tensor_parallel_mode == '1d'): + pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size)) + self.pg_manager = pg_manager + return pg_manager + + def shard_model(self, model: nn.Module, policy: Policy = None): + r""" + The function is used to shard the PyTorch model. + + Args: + model (`torch.nn.Model`): the origin huggingface model + shard_config (`ShardConfig`): the config for distribute information + policy (`Policy`): the custom policy for sharding + """ + sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager) + sharder.shard() + return model + + def shard_dataset(self, dataset: Dataset): + """ + Shard dataset for DP + """ + pass diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py deleted file mode 100644 index 860533dca50d..000000000000 --- a/colossalai/shardformer/shard/slicer.py +++ /dev/null @@ -1,163 +0,0 @@ -import torch - -from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer -from .shard_config import ShardConfig - -dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1} - - -class Slicer(): - - def __init__( - self, - shardconfig: ShardConfig #TODO - ) -> None: - self.shardconfig = shardconfig - - def slice_weight_bias( - self, - weight: torch.Tensor, - bias: torch.Tensor, - policy_layer_cls: Layer, - n_cast: int = None, - reversed: bool = False, - ): - r""" - Slice the weight and bias according to policy layer cls - ``Layer`` -> do nothing - ``Col_Layer`` -> slice the weight and bias along dim 1 - ``Row_Layer`` -> slice the weight along dim 0 and do not slice bias - - Args: - weight (:class:`torch.nn.Module`): The weight of the layer - bias: (:class:`torch.nn.Module`): The bias of the layer - policy_layer_class (:class:`Policy`): The class represent how to slice the tensor - """ - if policy_layer_cls in [Layer, Dropout_Layer]: - return weight, bias - - dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls]) - # print(weight.shape, dim) - if policy_layer_cls == Col_Layer: - weight = self.slice_tensor(weight, dim, False, n_cast) - bias = self.slice_tensor(bias, 0, True, n_cast) - elif policy_layer_cls == Row_Layer: - weight = self.slice_tensor(weight, dim, False, n_cast) - elif policy_layer_cls == Embedding_Layer: - weight = self.slice_tensor(weight, dim, False, n_cast) - else: - raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported") - if reversed: - weight = weight.transpose(0, 1).contiguous() - return weight, bias - - def slice_tensor( - self, - tensor_in: torch.Tensor, - dim: int, - is_bias: bool, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice tensor according to the config - - Args: - tensor_in (:class:`torch.Tensor`): The tensor to slice - dim (int): The dimension to slice - is_bias (bool): Whether the tensor is bias - """ - if tensor_in is None: - return None - if not is_bias: - return self.slice_2d(tensor_in, dim, n_cast) - else: - return self.slice_1d(tensor_in, n_cast) - - def slice_2d( - self, - tensor: torch.Tensor, - dim: int, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the 2D tensor - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - dim (int): The dimension to slice - """ - assert dim in [0, 1], f"Only support 2D tensor, but got {dim}D tensor" - if dim == 0: - return self.slice_row(tensor, n_cast) - elif dim == 1: - return self.slice_col(tensor, n_cast) - - def slice_1d( - self, - tensor: torch.Tensor, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the 1D tensor - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - - Returns: - :class:`torch.Tensor`: The sliced tensor - """ - if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() - else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) - chunk_list = [ - tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) - ] - return torch.cat(chunk_list, dim=0).contiguous() - - def slice_col( - self, - tensor: torch.Tensor, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the tensor in column - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - - Returns: - :class:`torch.Tensor`: The sliced tensor - - """ - if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous() - else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1) - chunk_list = [ - tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) - ] - return torch.cat(chunk_list, dim=1).contiguous() - - def slice_row( - self, - tensor: torch.Tensor, - n_cast: int = None, - ) -> torch.Tensor: - r""" - Slice the tensor in column - - Args: - tensor (:class:`torch.Tensor`): The tensor to slice - - Returns: - :class:`torch.Tensor`: The sliced tensor - """ - if n_cast is None: - return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous() - else: - tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0) - chunk_list = [ - tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size) - ] - return torch.cat(chunk_list, dim=0).contiguous() diff --git a/colossalai/shardformer/utils/__init__.py b/colossalai/shardformer/utils/__init__.py index e69de29bb2d1..b50e7b2f6d80 100644 --- a/colossalai/shardformer/utils/__init__.py +++ b/colossalai/shardformer/utils/__init__.py @@ -0,0 +1 @@ +from .utils import getattr_, hasattr_, setattr_ From 579b617dc99cef945d143ffb2f9718c68423b68b Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 15 Jun 2023 18:03:38 +0800 Subject: [PATCH 18/49] [shardformer] integrated linear 1D with dtensor (#3996) * [shardformer] integrated linear 1D with dtensor * polish code --- colossalai/nn/layer/base_layer.py | 1 + colossalai/shardformer/layer/_operation.py | 133 +++- colossalai/shardformer/layer/dropout.py | 54 +- colossalai/shardformer/layer/layers.py | 653 +++++++++--------- colossalai/shardformer/layer/utils.py | 138 ++++ colossalai/tensor/d_tensor/api.py | 44 ++ colossalai/tensor/d_tensor/layout.py | 21 +- .../tensor/d_tensor/layout_converter.py | 2 +- .../test_layer/test_linear_1d.py | 67 ++ 9 files changed, 706 insertions(+), 407 deletions(-) create mode 100644 colossalai/shardformer/layer/utils.py create mode 100644 colossalai/tensor/d_tensor/api.py create mode 100644 tests/test_shardformer/test_layer/test_linear_1d.py diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py index 5234b6b1a1b5..4a06bdcb7629 100644 --- a/colossalai/nn/layer/base_layer.py +++ b/colossalai/nn/layer/base_layer.py @@ -10,6 +10,7 @@ class ParallelLayer(nn.Module): + global_state_dict: bool = True def __init__(self): diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index e817ea3ebbee..208a391c33e2 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -54,10 +54,10 @@ class LinearWithAsyncCommunication(torch.autograd.Function): """ @staticmethod - def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce): + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): ctx.save_for_backward(input_, weight) ctx.use_bias = bias is not None - ctx.parallel_mode = parallel_mode + ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce output = torch.matmul(input_, weight.t()) @@ -74,12 +74,13 @@ def backward(ctx, grad_output): grad_input = grad_output.matmul(weight) grad_output = grad_output.contiguous() # Convert the tensor shapes to 2D for execution compatibility - grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) - total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) if ctx.async_grad_allreduce: # Asynchronous all-reduce - handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True) + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) # Delay the start of weight gradient computation shortly (3us) to have # all-reduce scheduled first and have GPU resources allocated _ = torch.empty(1, device=grad_output.device) + 1 @@ -93,5 +94,123 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None, None -def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce): - return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce) +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_ (`torch.Tensor`): input matrix. + dim (int): the dimension to perform split and gather + process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication + + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _split(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.dim, ctx.process_group), None, None + + +class _ReduceInput(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def forward(ctx, input_, process_group): + return _reduce(input_, process_group) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +def _reduce(input_, process_group): + # skip if only one rank involved + if dist.get_world_size(process_group) == 1: + return input_ + else: + dist.all_reduce(input_, group=process_group) + return input_ + + +def _split(input_, dim=-1, process_group=None): + # skip if only one rank involved + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # Split along last dimension. + dim_size = input_.size(dim) + assert dim_size % world_size == 0, \ + f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \ + f'cannot split tensor evenly' + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = dist.get_rank(process_group) + output = tensor_list[rank].contiguous() + + return output + + +def _gather(input_, dim=-1, process_group=None): + # skip if only one rank involved + world_size = dist.get_world_size(process_group) + if world_size == 1: + return input_ + + # all gather + rank = dist.get_rank(process_group) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=process_group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def forward(ctx, input_, dim, process_group): + ctx.process_group = process_group + ctx.dim = dim + return _gather(input_, dim, process_group) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.dim, ctx.process_group), None, None + + +def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): + return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) + + +def gather_forward_split_backward(input_, dim, process_group): + return _GatherForwardSplitBackward.apply(input_, dim, process_group) + + +def split_forward_gather_backward(input_, dim, process_group): + return _SplitForwardGatherBackward.apply(input_, dim, process_group) + + +def reduce_input(input_, process_group): + return _ReduceInput.apply(input_, process_group) diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index 0f653a9be780..5d295be6bd83 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -1,58 +1,20 @@ -import os -from contextlib import contextmanager - import torch +import torch.distributed as dist import torch.nn as nn - -class SeedManager: - """ - This class is a random state manager to change random state for different random seed. - - """ - - def __init__(self): - original_state = torch.cuda.get_rng_state() - # TODO: unify this seed manager with the colossalai.context.random - seed = os.getpid() - torch.cuda.manual_seed(int(seed)) - self.dropout_state = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(original_state) - - def set_mode(self, rng_state): - torch.cuda.set_rng_state(rng_state) - - def get_current_mode(self): - current_state = torch.cuda.get_rng_state() - return current_state - - @contextmanager - def dropout_mode(self): - """ - This is a context manager to change the dropout state and recover the original state. - - Usage: - :: - >>> with _seed_manager.dropout_mode(): - >>> input = super().forward(input) - """ - try: - current_mode = self.get_current_mode() - yield self.set_mode(self.dropout_state) - finally: - self.dropout_state = self.get_current_mode() - self.set_mode(current_mode) - - -_seed_manager = SeedManager() +from .utils import create_randomizer_with_offset class Dropout1D(nn.Dropout): - def __init__(self, p=0.5, inplace=False): + def __init__(self, p=0.5, inplace=False, process_group=None): super().__init__(p, inplace) + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=process_group) + def forward(self, input): - with _seed_manager.dropout_mode(): + with self.randomizer.fork_rng(): input = super().forward(input) return input diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py index a9f3cf5ad14c..2ad6523c9a86 100644 --- a/colossalai/shardformer/layer/layers.py +++ b/colossalai/shardformer/layer/layers.py @@ -2,12 +2,16 @@ # -*- encoding: utf-8 -*- import math +from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Callable, Tuple +from typing import Callable, List, Tuple, Union import torch +import torch.distributed as dist +import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter from colossalai.communication import broadcast @@ -22,13 +26,11 @@ gather_forward_split_backward, get_parallel_input, reduce_grad, - reduce_input, set_parallel_input, - split_forward_gather_backward, ) from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding -from colossalai.registry import LAYERS +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise from colossalai.utils.checkpointing import ( broadcast_state_dict, gather_tensor_parallel_state_dict, @@ -36,7 +38,13 @@ ) from colossalai.utils.cuda import get_current_device -from ._operation import linear_with_async_comm +from ._operation import ( + gather_forward_split_backward, + linear_with_async_comm, + reduce_input, + split_forward_gather_backward, +) +from .utils import create_randomizer_with_offset Fast_LN = None try: @@ -46,21 +54,44 @@ pass -# @LAYERS.register_module -class Linear1D(ColossalaiModule): - r"""Linear layer for 1D parallelism. +class ParallelModule(nn.Module, ABC): + + @abstractmethod + def from_native_module(module: nn.Module, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": + """ + Convert a native PyTorch module to a parallelized module. + + Args: + module (nn.Module): the module to be converted. + process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. + If this is a list, the process group at the ith index of the list will correspond to the process group + in the ith axis of the device mesh. Defaults to None, which means the global process group. + """ + pass + + +class Linear1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. Args: in_features (int): size of each input sample. out_features (int): size of each output sample. bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - gather_output (bool, optional): Whether to call all-gather on output, defaults to False. - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): + weight_initializer (`typing.Callable`): The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): + bias_initializer (`typing.Callable`): The initializer of bias, defaults to xavier uniform initializer. More details about ``initializer`` please refer to @@ -72,32 +103,281 @@ def __init__(self, out_features: int, bias: bool = True, dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, gather_output: bool = False, skip_bias_add: bool = False, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - parallel_input = get_parallel_input() - if not parallel_input and not gather_output: - layer = Linear1D_Col(in_features, - out_features, + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + self.out_features_per_partition = divide(out_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = Linear1D_Col(in_features=in_features, + out_features=out_features, bias=bias, - dtype=dtype, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on row is equal to shard on column + sharded_weight = shard_rowwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + if bias: + sharded_bias = shard_colwise(module.bias.data, process_group) + linear_1d.bias.copy_(sharded_bias) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class Linear1D_Row(ParallelModule): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) else: - layer = Linear1D_Row(in_features, - out_features, + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = Linear1D_Row(in_features=in_features, + out_features=out_features, bias=bias, - dtype=dtype, - parallel_input=parallel_input, - skip_bias_add=skip_bias_add, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer) - super().__init__(layer) + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on col is equal to shard on row + sharded_weight = shard_colwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + + if bias: + linear_1d.bias.copy_(module.bias.data) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=self.process_group, + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias -# @LAYERS.register_module class LayerNorm1D(ColossalaiModule): r""" Layer Normalization for colossalai @@ -152,7 +432,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): super()._save_to_state_dict(destination, prefix, keep_vars) -# @LAYERS.register_module class Classifier1D(ParallelLayer): r"""RowLinear with given weight. Classifier of 1D parallelism. @@ -288,7 +567,6 @@ def forward(self, input_: Tensor) -> Tensor: return output -# @LAYERS.register_module class VocabParallelClassifier1D(ParallelLayer): r"""ColLinear with given weight. Classifier of 1D parallelism. @@ -424,317 +702,8 @@ def forward(self, input_: Tensor) -> Tensor: # @LAYERS.register_module -class Linear1D_Col(ParallelLayer): - r"""Linear layer with column parallelism. - - The linear layer is defined as :math:`Y = XA + b`. A is parallelized along - its second dimension as :math:`A = [A_1, ..., A_p]`. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - gather_output (bool, optional): If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is :math:`Y_i = XA_i`, defaults to False - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - More details about ``initializer`` please refer to - `init `_. - """ - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size) - self.out_features_per_partition = out_features - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) - - if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - is_parallel_output = not self.gather_output - set_parallel_input(is_parallel_output) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - input_parallel = input_ - # Matrix multiply. - bias = self.bias if not self.skip_bias_add else None - # output_parallel = F.linear(input_parallel, self.weight, bias) - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True) - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - else: - output = output_parallel - - if self.skip_bias_add: - return output, self.bias - else: - return output - - -# @LAYERS.register_module -class Linear1D_Row(ParallelLayer): - r""" Linear layer with row parallelism - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - parallel_input (bool, optional): If set to ``True``, it's assumed that the input is split, defaults to False. - skip_bias_add (bool, optional): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): - super().__init__() - - self.stream_chunk_num = stream_chunk_num - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.parallel_input = parallel_input - self.skip_bias_add = skip_bias_add - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # Divide the weight matrix along the last dimension. - # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size) - self.input_size_per_partition = in_features - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) - - if self.stream_chunk_num > 1: - # TODO() work for inference only - self.chunk_weight() - if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - - def chunk_weight(self): - self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict({weight_key: self.weight}) - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. - if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) - input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) - - if self.stream_chunk_num > 1: - if self.training: - raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") - with torch.no_grad(): - output_parallel_list = [None for i in range(self.stream_chunk_num)] - handle_list = [] - for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=gpc.get_group(ParallelMode.PARALLEL_1D), - async_op=True) - handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) - for handle in handle_list: - handle.wait() - output = torch.cat(output_parallel_list, dim=-1) - else: - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - if not self.skip_bias_add: - if self.bias is not None: - output = output + self.bias - return output - else: - return output, self.bias - - -# @LAYERS.register_module class Embedding1D(ParallelLayer): r"""Embedding for 1D parallelism. @@ -842,7 +811,6 @@ def forward(self, input_: Tensor) -> Tensor: return output -# @LAYERS.register_module class VocabParallelEmbedding1D(ParallelLayer): r"""Embedding parallelized in the vocabulary dimension. @@ -960,7 +928,6 @@ def forward(self, input_: Tensor) -> Tensor: return output -# @LAYERS.register_module class Dropout1D(ParallelLayer): """Dropout layer of 1D parallelism. diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py new file mode 100644 index 000000000000..c3d6ab57e3e9 --- /dev/null +++ b/colossalai/shardformer/layer/utils.py @@ -0,0 +1,138 @@ +from contextlib import contextmanager + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +class Randomizer: + """ + Randomizer enables the program to be executed under a different seed within the context. + + Example: + + ```python + randomizer = Randomizer(seed=1024) + + with randomizer.fork(): + # do something here with seed 1024 + do_something() + ``` + + Args: + seed (int): The random seed to set. + enable_cpu (bool): fork the CPU RNG state as well. + with_index (bool): whether to use the index of the randomizer. + """ + + _INDEX = 0 + + def __init__(self, seed: int): + # TODO: remove colossalai.context.random + + self.seed = seed + + # Handle CUDA rng state + # 1. get the current rng state + # 2. set the seed and store the rng state + # 3. recover the original rng state + cuda_original_rng_state = torch.cuda.get_rng_state() + torch.cuda.manual_seed(seed) + self.cuda_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(cuda_original_rng_state) + + # to the same for cpu rng state + cpu_original_rng_state = torch.get_rng_state() + torch.manual_seed(seed) + self.cpu_rng_state = torch.get_rng_state() + torch.set_rng_state(cpu_original_rng_state) + + def _set_cuda_rng_state(self, rng_state): + torch.cuda.set_rng_state(rng_state) + + def _get_cuda_rng_state(self): + current_state = torch.cuda.get_rng_state() + return current_state + + def _set_cpu_rng_state(self, rng_state): + torch.set_rng_state(rng_state) + + def _get_cpu_rng_state(self): + current_state = torch.get_rng_state() + return current_state + + @contextmanager + def fork_rng(self, enable_cpu: bool = False): + """ + This is a context manager to change the dropout state and recover the original state. + + Usage: + :: + >>> with _seed_manager.dropout_mode(): + >>> input = super().forward(input) + """ + try: + current_cuda_rng_state = self._get_cuda_rng_state() + self._set_cuda_rng_state(self.cuda_rng_state) + + if enable_cpu: + current_cpu_rng_state = self._get_cpu_rng_state() + self._set_cpu_rng_state(self.cpu_rng_state) + yield + finally: + self.cuda_rng_state = self._get_cuda_rng_state() + self._set_cuda_rng_state(current_cuda_rng_state) + + if enable_cpu: + self.cpu_rng_state = self._get_cpu_rng_state() + self._set_cpu_rng_state(current_cpu_rng_state) + + @staticmethod + def index(): + """ + Return the index of the randomizer. The index is useful when the user wants + to introduce some randomness in the program. + + Note: + The index will increment by one each time this method is called. + + Example: + + ```python + # assume we need a randomizer to init the weight of different layers + # we can use the index of the randomizer to do so that + # each layer has its own randomizer with a different seed + base_seed = torch.random.initial_seed() + seed = base_seed + Randomizer.index() + randomizer = Randomizer(seed) + + with randomizer.fork(): + init_weights() + ``` + + """ + idx = Randomizer._INDEX + Randomizer._INDEX += 1 + return idx + + +def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None): + """ + Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer. + + Args: + seed (int): The base random seed to set. + enable_cpu (bool): fork the CPU RNG state as well. + process_group (ProcessGroup): the process group to get the rank from. + + Returns: + Randomizer: the randomizer with offset. + """ + offset = Randomizer.index() + + if dist.is_initialized(): + rank = dist.get_rank(process_group) + offset += rank + + seed += offset + return Randomizer(seed=seed) diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py new file mode 100644 index 000000000000..afb1fc003e02 --- /dev/null +++ b/colossalai/tensor/d_tensor/api.py @@ -0,0 +1,44 @@ +from typing import Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.device.device_mesh import DeviceMesh + +from .d_tensor import DTensor +from .sharding_spec import ShardingSpec + + +def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor: + """ + Shard the first dim of the given tensor + """ + # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group + if group_or_device_mesh is None: + group_or_device_mesh = dist.GroupMember.WORLD + + if isinstance(group_or_device_mesh, ProcessGroup): + device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) + else: + assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' + device_mesh = group_or_device_mesh + sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]}) + return DTensor(tensor, device_mesh, sharding_spec) + + +def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor: + """ + Shard the first dim of the given tensor + """ + # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group + if group_or_device_mesh is None: + group_or_device_mesh = dist.GroupMember.WORLD + + if isinstance(group_or_device_mesh, ProcessGroup): + device_mesh = DeviceMesh.from_process_group(group_or_device_mesh) + else: + assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' + device_mesh = group_or_device_mesh + sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]}) + return DTensor(tensor, device_mesh, sharding_spec) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index ee7ef74a99ae..f15956ea3d52 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -34,7 +34,7 @@ def __hash__(self) -> int: def get_sharded_shape_per_device(self): sharded_shape = list(self.entire_shape) for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): - mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) assert sharded_shape[ dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' @@ -45,14 +45,15 @@ def _sanity_check(self): sharding_spec = self.sharding_spec # make sure all axes in logical device mesh only be used once - dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) - for dim, shard_list in sharding_spec.dim_partition_dict.items(): - for element in shard_list: - if element in dim_check_list: - dim_check_list.remove(element) - else: - raise DuplicatedShardingDimensionError( - f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + if self.device_mesh.logical_mesh_id is not None: + dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) + for dim, shard_list in sharding_spec.dim_partition_dict.items(): + for element in shard_list: + if element in dim_check_list: + dim_check_list.remove(element) + else: + raise DuplicatedShardingDimensionError( + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): @@ -60,7 +61,7 @@ def _sanity_check(self): num_devices = 1 for element in shard_list: - num_devices *= self.device_mesh.mesh_shape[element] + num_devices *= self.device_mesh.shape[element] if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index cf02aac309f4..abc70e19a126 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -304,7 +304,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec process_groups_dict = source_layout.device_mesh.process_groups_dict # legal sharding dims means the mesh_id is still available to use. - legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))] + legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: legal_sharding_dims.remove(element) diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py new file mode 100644 index 000000000000..449522c64129 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -0,0 +1,67 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def check_linear_1d_col(): + linear = nn.Linear(32, 128).cuda() + linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) + + assert linear_col.weight.shape == torch.Size([64, 32]) + assert linear_col.bias.shape == torch.Size([64]) + + # check computation correctness + x = torch.rand(4, 32).cuda() + out = linear(x) + gather_out = linear_col(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, linear_col.weight.grad) + + +def check_linear_1d_row(): + linear = nn.Linear(32, 128).cuda() + linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + + assert linear_row.weight.shape == torch.Size([128, 16]) + assert linear_row.bias.shape == torch.Size([128]) + + # check computation correctness + x = torch.rand(4, 32).cuda() + out = linear(x) + gather_out = linear_row(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] + assert_close(target_grad, linear_row.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_linear_1d_col() + check_linear_1d_row() + + +@rerun_if_address_is_in_use() +def test_linear(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_linear() From bdc405e69eb45b73dadfb4b997aff76881dee6d6 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 16 Jun 2023 11:23:30 +0800 Subject: [PATCH 19/49] integrate with dist layer (#4011) --- colossalai/shardformer/policies/bert.py | 28 ++++++++++++++----- colossalai/shardformer/shard/sharder.py | 15 +++++----- .../test_model/test_shard_bert.py | 23 +++++++++------ 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index f3431c386fe4..fc3e8447337d 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -8,12 +8,6 @@ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -class ParallelModule(): - - def __init__(self): - pass - - class BertPolicy(Policy): def preprocess(self, shard_config: ShardConfig = None): @@ -49,7 +43,27 @@ def module_policy(self, shard_config: ShardConfig = None): sub_module_replacement=[ SubModuleReplacementDescription( suffix="attention.self.query", - target_module=ParallelModule, + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, ), ]) } diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 8eee3c6a3b7e..eb8300d5998e 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -7,8 +7,8 @@ from colossalai.cluster.process_group_manager import ProcessGroupManager from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Policy -from ..utils.utils import setattr_ +from ..policies.basepolicy import Policy, SubModuleReplacementDescription +from ..utils.utils import getattr_, setattr_ from .shard_config import ShardConfig __all__ = ['ModelSharder', 'shard_model'] @@ -90,9 +90,7 @@ def replace_module(self,) -> None: Args: model (:class:`torch.nn.Module`): The model to shard """ - print(self.policy) module_descriptions = self.policy.module_policy(self.shard_config) - print(f"*******{module_descriptions}") for module_description in module_descriptions.items(): origin_layer_cls = module_description[0] attr_replacement = module_description[1].attribute_replacement @@ -160,7 +158,7 @@ def _replace_param( def _replace_sub_module( self, org_layer: nn.Module, - sub_module_replacement: List[Callable], + sub_module_replacement: List[SubModuleReplacementDescription], ) -> None: r""" Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict @@ -177,7 +175,8 @@ def _replace_sub_module( assert target_module is not None, 'target_module should not be None' - # TODO: integrate with new layer - # replace_layer = target_module.from_native_layer(org_layer, self.pg_manager) - replace_layer = None + # TODO: support different parallel mode + native_sub_module = getattr_(org_layer, suffix) + replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d']) + setattr_(org_layer, suffix, replace_layer) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 9b29111eadb2..05d03343632f 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -17,7 +17,7 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.shard import ShardConfig, shard_model +from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.testing import rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -30,16 +30,21 @@ def build_model(rank, world_size, model): config.hidden_dropout_prob = 0 config.attention_probs_dropout_prob = 0 - org_model = model(config=config) + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config) org_model_forshard = copy.deepcopy(org_model) - org_model = org_model.to('cuda') - shardconfig = ShardConfig( - rank=rank, - world_size=world_size, - gather_output=True, - ) - sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda') + org_model.to('cuda') + # TODO: no need to transfer to cuda + org_model_forshard.to('cuda') + shard_config = ShardConfig(tensor_parallel_size=2, + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_mode='1d', + inference_only=True, + gather_output=True) + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + sharded_model = shard_former.shard_model(org_model_forshard).to('cuda') return org_model, sharded_model From 2c366e3f7206554fdf3b1424d7e5d1172ce1b860 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 16 Jun 2023 15:00:26 +0800 Subject: [PATCH 20/49] [shardformer] refactored embedding and dropout to parallel module (#4013) * [shardformer] refactored embedding and dropout to parallel module * polish code --- .../shardformer/layer/dist_crossentropy.py | 20 +- colossalai/shardformer/layer/dropout.py | 32 +- colossalai/shardformer/layer/layers.py | 467 +++--------------- .../test_layer/test_dropout.py | 53 ++ .../test_layer/test_embedding.py | 43 ++ .../test_layer/test_linear_1d.py | 2 +- 6 files changed, 196 insertions(+), 421 deletions(-) create mode 100644 tests/test_shardformer/test_layer/test_dropout.py create mode 100644 tests/test_shardformer/test_layer/test_embedding.py diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py index ff05209fefe8..7840c2f2e5da 100644 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ b/colossalai/shardformer/layer/dist_crossentropy.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.autograd import Function +from torch.distributed import ProcessGroup class DistCrossEntropy(Function): @@ -14,7 +15,7 @@ class DistCrossEntropy(Function): """ @staticmethod - def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int): + def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup): r""" Calculate the cross entropy loss before gather, the origin loss function is as follows: loss = -log(exp(x[class])/sum(exp(x[i])) @@ -34,15 +35,15 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: """ # get the max logits_max = torch.max(vocab_logits, dim=-1)[0] - dist.all_reduce(logits_max, op=dist.ReduceOp.MAX) + dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group) # minus the max to avoid the result of sum of exp is too large and the log is nan vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) # mask the target in the local device partition_vocab_size = vocab_logits.size()[-1] - rank = dist.get_rank() - world_size = dist.get_world_size() + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) global_vocab_size = partition_vocab_size * world_size # [down, up) => false, other device and -100 => true @@ -67,11 +68,11 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: pred_logits[mask] = 0.0 # allreduce the get all x(i,y) - dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM) + dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group) exp_logits = vocab_logits torch.exp(vocab_logits, out=exp_logits) sum_exp_logits = torch.sum(exp_logits, dim=-1) - dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM) + dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group) # calculate the loss # loss = log(sum(exp(x[i]))) - x[class] @@ -101,5 +102,8 @@ def backward(ctx, grad_output): return grad_logits, None, None -def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor: - return DistCrossEntropy.apply(vocab_logits, labels, ignore_index) +def cross_entropy_1d(vocab_logits: torch.Tensor, + labels: torch.Tensor, + ignore_index: int = -100, + process_group: ProcessGroup = None) -> torch.Tensor: + return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index 5d295be6bd83..ec08d072f338 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -1,19 +1,43 @@ +from typing import List, Union + import torch -import torch.distributed as dist import torch.nn as nn +from torch.distributed import ProcessGroup +from .layers import ParallelModule from .utils import create_randomizer_with_offset -class Dropout1D(nn.Dropout): +class Dropout1D(ParallelModule, nn.Dropout): + """ + The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with + randomness on different ranks of the given process group. This can avoid the same dropout mask is generated + and applied on the same position of different ranks, leading to poor convergence performance. + + Args: + p (float): probability of an element to be zeroed. Defaults to 0.5. + inplace (bool): If set to True, will do this operation in-place. Defaults to False. + process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None. + """ - def __init__(self, p=0.5, inplace=False, process_group=None): - super().__init__(p, inplace) + def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None): + # init with nn.Dropout + super(nn.Dropout, self).__init__(p=p, inplace=inplace) # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=process_group) + @staticmethod + def from_native_module(module: nn.Dropout, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Dropout1D": + """ + Create a Dropout1D layer from a native dropout layer. + """ + p = module.p + inplace = module.inplace + return Dropout1D(p=p, inplace=inplace, process_group=process_group) + def forward(self, input): with self.randomizer.fork_rng(): input = super().forward(input) diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py index 2ad6523c9a86..87d24f18e178 100644 --- a/colossalai/shardformer/layer/layers.py +++ b/colossalai/shardformer/layer/layers.py @@ -22,12 +22,7 @@ from colossalai.nn import init as init from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule -from colossalai.nn.layer.parallel_1d._utils import ( - gather_forward_split_backward, - get_parallel_input, - reduce_grad, - set_parallel_input, -) +from colossalai.nn.layer.parallel_1d._utils import get_parallel_input, reduce_grad, set_parallel_input from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise @@ -432,279 +427,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): super()._save_to_state_dict(destination, prefix, keep_vars) -class Classifier1D(ParallelLayer): - r"""RowLinear with given weight. Classifier of 1D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - self.parallel_input = get_parallel_input() - - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size) - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = False - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.num_classes - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D) - - def _set_tensor_parallel_attributes(self): - if self.has_weight: - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: -1, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: False - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. - if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) - input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) - - output_parallel = F.linear(input_, self.weight) - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - if self.bias is not None: - output = output + self.bias - return output - - -class VocabParallelClassifier1D(ParallelLayer): - r"""ColLinear with given weight. Classifier of 1D parallelism. - - Args: - in_features (int): size of each input sample. - num_classes (int): number of classes. - weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - num_classes: int, - weight: Parameter = None, - bias: bool = True, - dtype: torch.dtype = None, - gather_output: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - self.in_features = in_features - self.num_classes = num_classes - self.gather_output = gather_output - self.parallel_input = get_parallel_input() - - # Divide the weight matrix along the last dimension. - self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size) - - # Parameters. - # Initialize weight. - factory_kwargs = {'device': get_current_device(), 'dtype': dtype} - if weight is not None: - self.weight = weight - self.has_weight = False - else: - self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs)) - self.has_weight = True - if bias: - self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs)) - else: - self.bias = None - with seed(ParallelMode.TENSOR): - self.reset_parameters(weight_initializer, bias_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = True - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.num_classes - if self.has_weight: - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def _set_tensor_parallel_attributes(self): - num_partition = gpc.get_world_size(ParallelMode.TENSOR) - if self.has_weight: - set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - if self.bias is not None: - set_tensor_parallel_attribute_by_partition(self.bias, num_partition) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - if self.has_weight: - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - if self.bias is not None: - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - local_state = OrderedDict() - if self.has_weight: - local_state[weight_key] = self.weight - if self.bias is not None: - local_state[bias_key] = self.bias - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={ - weight_key: 0, - bias_key: 0 - }, - partition_states={ - weight_key: True, - bias_key: True - }, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in VocabParallelClassifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - # Matrix multiply. - output_parallel = F.linear(input_parallel, self.weight, self.bias) - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - else: - output = output_parallel - return output - - -# @LAYERS.register_module - - -class Embedding1D(ParallelLayer): +class Embedding1D(ParallelModule): r"""Embedding for 1D parallelism. Args: @@ -739,7 +462,8 @@ def __init__(self, embedding_dim: int, padding_idx: int = None, dtype: torch.dtype = None, - gather_output: bool = True, + device: torch.device = None, + process_group: ProcessGroup = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -747,66 +471,79 @@ def __init__(self, self.num_embeddings = num_embeddings self.embed_dim = embedding_dim - embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size) + self.process_group = process_group + self.num_partitions = dist.get_world_size(process_group) + self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs self.gather_output = gather_output - self.weight = Parameter( - torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)) + if device is None: + device = get_current_device() - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) + self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D": + r""" + Build a 1D parallelized Embedding from a native nn.Embedding module. + """ + # get the attributes + num_embedding = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + max_norm = module.max_norm + norm_type = module.norm_type + scale_grad_by_freq = module.scale_grad_by_freq + sparse = module.sparse + dtype = module.weight.dtype + device = module.weight.device + + # sparse is not support yet + if sparse: + raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") + + embedding = Embedding1D(num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + + # copy the weight + with torch.no_grad(): + sharded_weight = shard_colwise(module.weight.data, process_group) + embedding.weight.copy_(sharded_weight) + + return embedding def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): self.weight[self.padding_idx].fill_(0) - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: -1}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - def forward(self, input_: Tensor) -> Tensor: - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - if self.gather_output: - output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1) - else: - output = output_parallel + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) return output @@ -926,89 +663,3 @@ def forward(self, input_: Tensor) -> Tensor: # Reduce across all the model parallel GPUs. output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) return output - - -class Dropout1D(ParallelLayer): - """Dropout layer of 1D parallelism. - - Args: - p (float, optional): probability of an element to be zeroed, defaults 0.5. - inplace (bool, optional): whether to do dropout in-place, default to be False. - """ - - def __init__(self, p: float = 0.5, inplace: bool = False): - super().__init__() - self.parallel_input = get_parallel_input() - self.p = p - self.inplace = inplace - - def forward(self, input_: Tensor) -> Tensor: - if self.parallel_input: - with seed(ParallelMode.TENSOR): - output = F.dropout(input_, self.p, self.training, self.inplace) - else: - output = F.dropout(input_, self.p, self.training, self.inplace) - return output - - -# @LAYERS.register_module -class PatchEmbedding1D(ColossalaiModule): - """ - 2D Image to Patch Embedding - - :param img_size: image size - :type img_size: int - :param patch_size: patch size - :type patch_size: int - :param in_chans: number of channels of input image - :type in_chans: int - :param embed_size: size of embedding - :type embed_size: int - :param dtype: The dtype of parameters, defaults to None - :type dtype: torch.dtype, optional - :param flatten: whether to flatten output tensor, defaults to True - :type flatten: bool, optional - :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer - :type weight_initializer: typing.Callable, optional - :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer - :type bias_initializer: typing.Callable, optional - :param position_embed_initializer: The initializer of position embedding, defaults to zero - :type position_embed_initializer: typing.Callable, optional - """ - - def __init__(self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - dtype: torch.dtype = None, - flatten: bool = True, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_()): - embed = VanillaPatchEmbedding(img_size, - patch_size, - in_chans, - embed_size, - dtype=dtype, - flatten=flatten, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - position_embed_initializer=position_embed_initializer) - super().__init__(embed) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - param_keys = [prefix + 'weight', prefix + 'bias', prefix + 'cls_token', prefix + 'pos_embed'] - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - for key in param_keys: - param = state_dict.pop(key, None) - if param is not None: - local_state[key] = param - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py new file mode 100644 index 000000000000..c48c11b36d91 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -0,0 +1,53 @@ +import torch +import torch.distributed as dist +import torch.nn as nn + +import colossalai +from colossalai.shardformer.layer.dropout import Dropout1D +from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn + + +def check_dropout(): + dropout = nn.Dropout().cuda() + dropout_1d = Dropout1D.from_native_module(dropout, process_group=None) + + # check computation correctness + x = torch.rand(4, 128).cuda() + + # we set seed so that dropout will generate the same mask + torch.cuda.manual_seed(1024) + out = dropout(x) + + # we set seed to simulate the same scenario + # but expect the dropout mask to be different + # due to the internal randomness control + torch.cuda.manual_seed(1024) + out_1d = dropout_1d(x) + + # ensure out is the same across all ranks + world_size = dist.get_world_size() + out_all = [torch.empty_like(out) for _ in range(world_size)] + dist.all_gather(out_all, out) + + for i in range(world_size): + assert_equal(out_all[i], out_all[0]) + + # ensure out_1d is different across ranks + out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)] + dist.all_gather(out_1d_all, out_1d) + for i in range(1, world_size): + assert_not_equal(out_1d_all[i], out_1d_all[0]) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_dropout() + + +@rerun_if_address_is_in_use() +def test_dropout(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_dropout() diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py new file mode 100644 index 000000000000..462349ecb93b --- /dev/null +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -0,0 +1,43 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer.layers import Embedding1D +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_embedding_1d(): + embedding = nn.Embedding(32, 128).cuda() + embedding_1d = Embedding1D.from_native_module(embedding, process_group=None) + + assert embedding_1d.weight.shape == torch.Size([32, 64]) + + # check computation correctness + x = torch.randint(low=0, high=32, size=(4, 32)).cuda() + out = embedding(x) + gather_out = embedding_1d(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(embedding.weight.grad, 2, dim=1)[rank] + assert_close(target_grad, embedding_1d.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_embedding_1d() + + +@rerun_if_address_is_in_use() +def test_embedding_1d(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_embedding_1d() diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 449522c64129..2a3ce99384cb 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -5,7 +5,7 @@ import colossalai from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_linear_1d_col(): From eaa46d7eda5e8aabd30a67ed5e17ff08d1e0d6a6 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 16 Jun 2023 15:58:27 +0800 Subject: [PATCH 21/49] [shardformer] removed inplace tensor sharding (#4018) --- colossalai/shardformer/layer/layers.py | 4 +++ colossalai/tensor/d_tensor/api.py | 40 +++++++++++++++++++++++--- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py index 87d24f18e178..586aec124b86 100644 --- a/colossalai/shardformer/layer/layers.py +++ b/colossalai/shardformer/layer/layers.py @@ -329,7 +329,11 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: src_rank = 0 else: src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + self.bias = self.bias.cuda() dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias = self.bias.to(origin_device) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index afb1fc003e02..b58edadfef20 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -10,9 +10,21 @@ from .sharding_spec import ShardingSpec -def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor: +def shard_rowwise(tensor: torch.Tensor, + group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None, + inplace: bool = False) -> DTensor: """ - Shard the first dim of the given tensor + Shard the first dim of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be sharded. + group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor. + If None, the tensor will be sharded with respect to the global process group. + Defaults to None. + inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False. + + Returns: + DTensor: The sharded tensor. """ # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group if group_or_device_mesh is None: @@ -24,12 +36,28 @@ def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' device_mesh = group_or_device_mesh sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]}) + + if not inplace: + tensor = tensor.detach().clone() + return DTensor(tensor, device_mesh, sharding_spec) -def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor: +def shard_colwise(tensor: torch.Tensor, + group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None, + inplace: bool = False) -> DTensor: """ - Shard the first dim of the given tensor + Shard the first dim of the given tensor. + + Args: + tensor (torch.Tensor): The tensor to be sharded. + group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor. + If None, the tensor will be sharded with respect to the global process group. + Defaults to None. + inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False. + + Returns: + DTensor: The sharded tensor. """ # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group if group_or_device_mesh is None: @@ -41,4 +69,8 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' device_mesh = group_or_device_mesh sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]}) + + if not inplace: + tensor = tensor.detach().clone() + return DTensor(tensor, device_mesh, sharding_spec) From 60eb38062495af2397c11c472f4e74fcfb7d86bc Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 16 Jun 2023 15:04:07 +0800 Subject: [PATCH 22/49] add vocabembedding layer --- colossalai/shardformer/layer/layers.py | 65 ++++++++++++++++--- .../test_vocab_parallel_embedding_1d.py | 45 +++++++++++++ 2 files changed, 100 insertions(+), 10 deletions(-) create mode 100644 tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py index 586aec124b86..ad6e1896aa5e 100644 --- a/colossalai/shardformer/layer/layers.py +++ b/colossalai/shardformer/layer/layers.py @@ -139,6 +139,7 @@ def __init__(self, with self.randomizer.fork_rng(enable_cpu=True): self.reset_parameters(weight_initializer, bias_initializer) + @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs) -> ParallelModule: r""" @@ -587,6 +588,8 @@ def __init__(self, embedding_dim: int, padding_idx: int = None, dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, weight_initializer: Callable = init.normal_(), *args, **kwargs): @@ -596,21 +599,63 @@ def __init__(self, self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs + self.process_group = process_group - tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.num_embeddings_per_partition = num_embeddings + tensor_parallel_size = dist.get_world_size(group=process_group) + tensor_parallel_rank = dist.get_rank(group=process_group) + + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.num_embeddings = self.num_embeddings_per_partition self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)) + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype)) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + # self.reset_parameters(weight_initializer) + # self._set_tensor_parallel_attributes() + # set_parallel_input(False) + # env.vocab_parallel = True + + @staticmethod + def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native pytorch embedding module to a parallel module. + """ + # get the origin attributes + num_embeddings = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + device = module.weight.device + + # ensure only one process group is used + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + # create the parallel module + vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + process_group=process_group, + *args, + **kwargs) + with torch.no_grad(): + # shard and slice the weight along the vocabulary(num_embeddings) dimension + # the shape of the weight is (num_embeddings, embedding_dim) + shard_weight = shard_rowwise(module.weight.data, process_group) + vocab_embedding_1d.weight.data.copy_(shard_weight) - self.reset_parameters(weight_initializer) - self._set_tensor_parallel_attributes() - set_parallel_input(False) - env.vocab_parallel = True + return vocab_embedding_1d def _set_tensor_parallel_attributes(self): set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) @@ -665,5 +710,5 @@ def forward(self, input_: Tensor) -> Tensor: # Mask the output embedding. output_parallel[input_mask, :] = 0. # Reduce across all the model parallel GPUs. - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) + output = reduce_input(output_parallel, self.process_group) return output diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py new file mode 100644 index 000000000000..3df53e8a8458 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -0,0 +1,45 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + + +def check_vocab_embedding_1d(): + embedding = nn.Embedding(128, 32).to('cuda') + dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None) + + assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) + assert dist_embedding_1d.num_embeddings == 64 + assert dist_embedding_1d.embed_dim == 32 + + # check embedding correctness + x = torch.randint(0, 128, (4, 32)).to('cuda') + org_out = embedding(x) + dist_out = dist_embedding_1d(x) + assert_close(org_out, dist_out) + + # check backward correctness + org_out.sum().backward() + dist_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(embedding.weight.grad, 2, dim=0)[rank] + assert_close(target_grad, dist_embedding_1d.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_vocab_embedding_1d() + + +@rerun_if_address_is_in_use() +def test_vocab_embedding(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_vocab_embedding() From 90e1a0ac22aeb23c8146aa4c628ecf8a56417f6b Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 16 Jun 2023 16:12:27 +0800 Subject: [PATCH 23/49] support bert with new api --- colossalai/shardformer/policies/bert.py | 35 ++++++++++++++++++++++++- colossalai/shardformer/shard/sharder.py | 5 ++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index fc3e8447337d..fe74f83ca745 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -2,6 +2,7 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead import colossalai.shardformer.layer.layers as col_nn +from colossalai.shardformer.layer.dropout import Dropout1D from ..shard.shard_config import ShardConfig from ..utils import getattr_, setattr_ @@ -65,7 +66,24 @@ def module_policy(self, shard_config: ShardConfig = None): suffix="output.dense", target_module=col_nn.Linear1D_Row, ), - ]) + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=Dropout1D, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=Dropout1D, + ) + ]), + BertEmbeddings: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ]) } def new_model_class(self): @@ -87,6 +105,21 @@ class BertForMaskedLMPolicy(BertPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self, shard_config: ShardConfig = None): + module_policy = super().module_policy(shard_config) + addon_module = { + BertLMPredictionHead: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="decoder", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + return module_policy + # BertLMHeadModel class BertLMHeadModelPolicy(BertPolicy): diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index eb8300d5998e..5c8584595c0c 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -171,12 +171,13 @@ def _replace_sub_module( for description in sub_module_replacement: suffix = description.suffix target_module = description.target_module - kwargs = description.kwargs + kwargs = {} if description.kwargs is None else description.kwargs assert target_module is not None, 'target_module should not be None' # TODO: support different parallel mode native_sub_module = getattr_(org_layer, suffix) - replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d']) + replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'], + **kwargs) setattr_(org_layer, suffix, replace_layer) From 38ceded2cfb16cbda074eb85982448bdee19e84d Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 16 Jun 2023 16:15:10 +0800 Subject: [PATCH 24/49] [shardformer] updated doc (#4016) --- colossalai/shardformer/README.md | 504 ++++++++++++++++--------------- 1 file changed, 257 insertions(+), 247 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index b8357c203939..dc2946ec937f 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -6,9 +6,15 @@ - [📚 Table of Contents](#-table-of-contents) - [🔗 Introduction](#-introduction) - [🔨 Usage](#-usage) - - [🔮 Simple example](#-simple-example) - - [💡 Policy](#-policy) - - [😊 Module](#-module) + - [Quick Start](#quick-start) + - [Write your own policy](#write-your-own-policy) + - [🗺 Roadmap](#-roadmap) + - [💡 API Design](#-api-design) + - [Distributed Modules](#distributed-modules) + - [Shard Config](#shard-config) + - [Policy](#policy) + - [Model Sharder](#model-sharder) + - [User-facing API](#user-facing-api) ## 🔗 Introduction @@ -17,299 +23,303 @@ ## 🔨 Usage +### Quick Start + The sample API usage is given below: ``` python -from colossalai.shardformer import ShardConfig, shard_model +from colossalai.shardformer import ShardConfig, Shard from transformers import BertForMaskedLM -# create huggingface model as normal -model = BertForMaskedLM.from_pretrained("bert-base-uncased") - -# make the huggingface model paralleled to ShardModel -# auto policy: -shardconfig = ShardConfig( - rank=rank, - world_size=world_size, - gather_output=True, -) -sharded_model = shard_model(model, config=shardconfig) - -# custom policy: -from xxx import -sharded_model = shard_model(model, ) - -# do angthing as normal -... -``` +# launch colossalai +colossalai.launch_from_torch() -## 🔮 Simple example +# create model +config = BertConfig.from_pretrained('bert-base-uncased') +model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config) -``` shell -# inference -colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode inference -# train -colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py --mode train +# create huggingface model as normal +shard_config = ShardConfig(tensor_parallel_size=2, + data_parallel_size=1, + gather_output=True) +shard_former = ShardFormer(shard_config=shard_config) +shard_former.init_distributed() +sharded_model = shard_former.shard_model(model).to('cuda') + +# do everything like normal +... ``` +### Write your own policy -## 💡 Policy - -If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. Please refer to any policy that we have pre-established, like [bert policy](./policies/bert.py) or [gpt2 policy](./policies/gpt2.py). - -You should do: - -1. Inherit Policy class -2. Overwrite `argument_policy` method - - In this method, you need to list which layers class you wanna modify and the attributes and parameters in those layers. Shardformer will replace all the layer belonging to the class you specified. - - `attr_dict` is dict contains all the attributes need to be modified in this layer. - - `param_funcs` is a list contains some functions which will return the path of the weight and bias from the layer. -3. Overwrite `inject_policy` method (Optional) - - Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method. -4. Overwrite or add the param functions - - These functions use a suffix to record the path of weight or bias for the layer. - - The return is a list contains some `Col_Layer`, `Row_Layer` or `Dropout_Layer` objects, which means slice along col and row respectively or as dropout layer, refer to CLASS `Layer` for more details. -5. Overwrite `binding_policy` (Optional) - - Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers. - - This function will return a dict, the key and value are the suffix of weight need to be binded. - -More details can be found in shardformer/policies/basepolicy.py -``` python -from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument - -class CustomPolicy(Policy): -@staticmethod - def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: - r""" - Return the dict for the modify policy, the key is the original layer class and the value is the - argument for the modify layer - - Args: - model_config (:class:`tansformer.Config`): The config of transformer model - shard_config (:class:`ShardConfig`): The config for sharding model - - Return: - Dict for the modify policy, - :: - { - origin layer class1 (nn.Module): Argument( - attr_dict = { - argument1: value1, - argument2: value2, - ... - }, - param_funcs = [ - staticmethod1, - staticmethod2, - ... - ] - ), - origin layer class2 (nn.Module): Argument( - attr_dict = { - argument1: value1, - argument2: value2, - ... - }, - param_funcs = [ - staticmethod1, - staticmethod2, - ... - ] - ), - ... - } - - """ - raise NotImplementedError - - @staticmethod - def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]: - r""" - Return the dict for the inject model - - Return: - The injected model, key is the original model and value is the new shardmodel - :: - (OrignModel, CustomModel) - in `CustomModel`, we can overwrite the forward and backward process - """ - return None - - @staticmethod - def binding_policy() -> Union[Dict[str, str], None]: - r""" - Return the dict for the binding model, None means no need to bind - - Return: - This method should return the binding relationship for some layers share the weight or bias, - the key and value is the suffix of the weight or bias of the model - :: - return { - "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", - } - """ - return None - - @staticmethod - def attn_in() -> Union[List, None]: - r""" - Attention qkv layer - In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be - ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters - in ``Layer`` object can refer to the ``Layer`` class. - - Returns: - List[Layer]: List of layer object, each layer is the new - """ - return None - - @staticmethod - def attn_out() -> Union[List, None]: - r""" - Attention output projection layer - - Returns: - List[Layer]: List of layer object - """ - return None +If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design). - @staticmethod - def mlp_in() -> Union[List, None]: - r""" - h -> 4h mlp layer +```python +from colossalai.shardformer import Policy - Returns: - List[Layer]: List of layer object - """ - return None +class MyPolicy(Policy): + # implement your own policy + ... - @staticmethod - def mlp_out() -> Union[List, None]: - r""" - 4h -> h mlp layer +# init model and shard former +... - Returns: - List[Layer]: List of layer object - """ - return None +# use customized policy to shard model +my_policy = MyPolicy() +shard_former.shard_model(model, my_policy) - @staticmethod - def embedding() -> Union[List, None]: - r""" - Partially slice the embedding layer +``` - Return: - List[Layer]: List of layer object +## 🗺 Roadmap + +We will follow this roadmap to develop Shardformer: + +- [x] API Design +- [x] API Implementation +- [x] Unit Testing +- [ ] Policy Implementation + - [ ] Hugging Face + - [ ] NLP + - [x] BERT + - [ ] T5 + - [ ] LlaMa + - [ ] GPT2 + - [ ] BLOOM + - [ ] RoBERTa + - [ ] ALBERT + - [ ] ERNIE + - [ ] GPT Neo + - [ ] GPT-J + - [ ] CV + - [ ] CV + - [ ] ViT + - [ ] BEiT + - [ ] SwinTransformer + - [ ] SwinTransformer V2 + - [ ] Audio + - [ ] To be added + - [ ] Multi-modal + - [ ] To be added + +## 💡 API Design + +We will discuss the major components of `ShardFormer` below to help you better understand how things work. +This section serves as the design doc for Shardformer and the function signature might differ from the actual implementation. +Please refer to the code for more details. + +

+ +
+ This diagram is deprecated, need to update it +

+ + + +### Distributed Modules + +`ShardFormer` replaces the original PyTorch module with a distributed module. +The distributed module keeps the same attributes as the original module but replaces the original parameters with distributed parameters and defines a new `forward` function to execute distributed computation. +Each distributed module implements its `from_native_module` static method to convert the PyTorch module to its corresponding distributed module. + +```python +class ParallelModule(torch.nn.Module): + + @abstractmethod + def from_native_module(module: torch.nn.Module, process_group: Union[ProcessGroup, Tuple[ProcessGroup]]) -> ParallelModule """ - return None + Convert a native module to a parallelized - @staticmethod - def unembedding() -> Union[List, None]: - r""" - Partially slice the embedding layer, None means there is no unembedding layer + Examples: - Return: - List[Layer]: List of layer object + ```python + # replace module + my_linear = Linear1D_Col.from_native_module(my_linear, process_group) + ``` """ - return None - ``` +### Shard Config -## 😊 Module - - 1. Flowchart - -

- -

- - 2. Important Modules - - - CLASS `shard_model`: - - This is the user api to use shardformer, just create a model from transformers and define a custom policy or use shardformer autopolicy to make a shard model. - - - CLASS `Layer`: - - Parameters: - - suffix: (str): the suffix of the layer to indicate the attribute of the layer. - - replace_layer (:class:`colosalai.nn`): The layer to replace the original layer - - ignore (bool): Whether to ignore this layer if it is not in the model - - reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], but in GPT2 `Conv1D` layer is [in, out] which is reversed. - - n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices, but in multi-head attention, we need to chunk the weight with the number of $ devices * n\_head $, and each device should have a part of Q, K and V weight. - - This class is a base class used to specify the replacement policy and the suffix the layer for a particular layer. +`ShardConfig` is a simple data class to tell `ShardFormer` how sharding will be performed. - CLASS `Col_Layer(Layer)`: - - weight (str): The weight suffix of the layer - - bias (str): The bias suffix of the layer - - gather_output (bool): Whether the output of this layer can be gathered, like the last layer can be gathered, but most of the time, the intermediate layers of the model do not need to be gathered. +```python +@dataclass +class ShardConfig: + data_parallel_size: int + tensor_parallel_size: int + ... - This class inherited from `Layer`, representing the layer will be sliced along colum and indicate the attributes of weight and bias. Setting `bias` to `None` means ignoring bias, regardless of whether or not it originally exists. - - CLASS `Row_Layer(Layer)`: - - weight (str): The weight suffix of the layer - - bias (str): The bias suffix of the layer - - This class inherited from `Layer`, representing the layer will be sliced along row. Just like `Col_Layer` but in tensor parrallel, there is no need to gather the output of layer sliced by row. - - - CLASS `Policy`: - - In Shardformer, this class holds significant importance as it defines the model partitioning methods, required parameter modifications, and model injection techniques all within a single Policy class. - - `Policy.attn_in()/attn_out()/mlp_in()/mlp_out()/embedding()/unembedding()`...... - - These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions. + # Some possible future config fields + pipeline_parallel_size: int # Support pipeline parallelism + tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode + inference_only: bool # only inject inference-suitable sharding policy + gather_output: bool # gather the model output + use_flash_attention: bool # whether to use flash attention to speed up attention +``` - - `Policy.argument_policy()` +### Policy - In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach. +The `Policy` class describes how to handle the model sharding. +It is merely a description, the actual sharding will be performed by `ModelSharder`. +We abstract the policy into four stages: - - `Policy.inject_policy()` +1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding +2. Providing a new class: call `Policy.new_model_class` to get a new class for the model, this class replaces attributes and the forward function +3. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted. +4. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model. - This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else. +``` python +@dataclass +class ModulePolicyDescription: + """ + Describe how the attributes and parameters will be transformed in a policy + + Args: + attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding + param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive two arguments: module, process_group. One example is + def example_replace_weight(module: torch.nn.Module, process_group): + weight = module.weight + new_weight = shard_rowwise(weight, process_group) + module.weight = torch.nn.Parameter(new_weight) + sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies the module to be replaced and the target module used to replacement + """ + attribute_replacement: Dict[str, Any] + param_replacement: List[Callable] + sub_module_replacement: List[SubModuleReplacementDescription] + +@dataclass +class SubModuleReplacementDescription: + """ + Describe how a submodule will be replaced + + Args: + suffix (str): used to get the submodule object + target_module (ParallelModule): specifies the module class used to replace to submodule + kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. + """ + suffix: str + target_module: ParallelModule + kwargs: Dict[str, Any] = None + + +class Policy(ABC): + + def __init__(self) + self.model = None + + def set_model(self, model: nn.Module) -> None: + """ + Set model as an attribute of the Policy object so that we can access the model's attributes. + """ + self.model = model - - `Policy.binding_policy()` + @abstractmethod + def preprocess(self) -> nn.Module: + """ + Perform some preprocessing on the model, such as resizing the embedding size + """ + ... - This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters. + @abstractmethod + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + """ + Return the dict for the modify policy, the key is the original layer class and the value is the + argument for the modify layer + """ + ... + @abstractmethod + def new_model_class(self) -> Union[Type[nn.Module], None]: + """ + replace the class of the model to substitute the forward and attributes + """ + ... - - CLASS `ModelSharder(model, policy)`: + @abstractmethods + def postprocess(self) -> nn.Module: + """ + Perform some postprocessing on the model, such as binding the embedding with the weight of the classifier head + """ + ... +``` - This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model. - - `ModelShard.inject_model()` +### Model Sharder - This function is used to inject the model to modify the forward and backward progress. +`ModelSharder` is the class in charge of sharding the model based on the given policy. - - `ModelShard.replace_layer()` +```python +class ModelSharder: - This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication. + def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None) + #TODO: input is a cls or a obj - - `ModelShard.bind_layer()` + def shard(self) -> None: + """ + Shard model with parallelelism with the help of pre-processing, replace_model_class, replace_module, and post-processing. + """ + ... - This function is used to help different layers share weight or bias. + def replace_model_class(self) -> None: + """ + Replace the model's methods and attributes with our own defined class. + E.g. we can replace the forward function of the original BertForMaskedLM object + with the forward function we define in BertForMaskedLM_ class. + """ + ... - - CLASS `Slicer`: + def replace_module(self) -> None: + """ + Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively. + """ + ... +``` - This class is used to slice tensor according to policy. +### User-facing API +We only expose a limited number of APIs to the user to keep their user experience simple and clean. - 3. DistCrossEntropy Loss - - Overview +```python +class ShardFormer: + """ + Parallelize model based on the given config and policy - In order to reduce the communication size, caculate the crossentropy before all gather, refer to [Megatron-LM](https://github.com/NVIDIA/Megatron-LM), reduce the communication size from [batch_size * seq_length * vocab_size] to [batch_size * seq_length]. The origin loss function is: - $$ loss = -\log(\frac{\exp(x[class])}{\sum_i\exp(x[i])})$$ + Example: - alse can be represented as: + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + model = shard_former.shard_model(model, policy=policy) + dataloader = shard_former.shard_dataset(dataset) - $$ loss = \log(\sum_i\exp(x[i])) - x[class]$$ + """ - - Step + def __init__(self, shard_config: ShardConfig): + """ + Do two things: + 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp + 2. serve as a store for shard config + """ + self.shard_config = shard_config + self.pg_manager = None - - First get the maximum logits across all the devices, make all the logist minus the maximun value to scale the value less than zero to avoid the value of exp being too large + def init_distributed(self) -> colossalai.cluster.ProcessGroupManager: + """ + Initialize the distributed process group according to the + """ + pg_manager = ... + self.pg_manager = pg_manager + return pg_manager - - Get a mask to mask the logits not in the local device + def shard_model(self, model: torch.nn.Module,policy: Policy) -> torch.nn.Module: + """ + Shard model for TP and PP + """ + ... - - Caculate the loss according to the second formula + def shard_dataset(self, dataset: Dataset) -> Dataloader: + """ + Shard dataset for DP + """ + ... +``` From c9827698c60dc513a23497b39ff43449cf7b7d37 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Mon, 19 Jun 2023 10:47:16 +0800 Subject: [PATCH 25/49] [shardformer] fix bert and gpt downstream with new api (#4024) * fix bert downstream with new api * remove comment line --- colossalai/shardformer/policies/basepolicy.py | 14 ++- colossalai/shardformer/policies/bert.py | 92 +++++++++++++++---- colossalai/shardformer/shard/shard_config.py | 6 +- colossalai/shardformer/shard/sharder.py | 9 +- colossalai/shardformer/shard/shardformer.py | 4 - .../test_model/test_shard_bert.py | 11 +-- 6 files changed, 97 insertions(+), 39 deletions(-) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 80ea7a252131..baae95980c14 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -76,6 +76,7 @@ class for the example. def __init__(self) -> None: self.model = None + self.shard_config = None def set_model(self, model: nn.Module) -> None: r""" @@ -86,14 +87,23 @@ def set_model(self, model: nn.Module) -> None: """ self.model = model + def set_shard_config(self, shard_config: ShardConfig) -> None: + r""" + Set shard config as an attribute of the Policy object. + + Args: + shard_config (:class:`ShardConfig`): The shard config to be perform + """ + self.shard_config = shard_config + @abstractmethod - def preprocess(self, shard_config: ShardConfig = None) -> nn.Module: + def preprocess(self) -> nn.Module: r""" Perform some preprocessing of the model, like reshaping the embedding layer """ @abstractmethod - def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: r""" Return the dict for the modify policy, the key is the original layer class and the value is the argument for the modify layer diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index fe74f83ca745..06ee9b435e7e 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -4,41 +4,40 @@ import colossalai.shardformer.layer.layers as col_nn from colossalai.shardformer.layer.dropout import Dropout1D -from ..shard.shard_config import ShardConfig from ..utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class BertPolicy(Policy): - def preprocess(self, shard_config: ShardConfig = None): + def preprocess(self): # reshape the embedding layer r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ # TODO: vocab_size = self.model.config.vocab_size - world_size = shard_config.tensor_parallel_size + world_size = self.shard_config.tensor_parallel_size if vocab_size % world_size != 0: new_vocab_size = vocab_size + world_size - vocab_size % world_size self.model.resize_token_embeddings(new_vocab_size) return self.model - def module_policy(self, shard_config: ShardConfig = None): + def module_policy(self): return { BertLayer: ModulePolicyDescription( attribute_replacement={ # 1. shard hidden size "attention.self.all_head_size": - self.model.config.hidden_size // shard_config.tensor_parallel_size, + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "crossattention.self.all_head_size": - self.model.config.hidden_size // shard_config.tensor_parallel_size, + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, # 2. shard number of heads "attention.self.num_attention_heads": - self.model.config.num_attention_heads // shard_config.tensor_parallel_size, + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, "crossattention.self.num_attention_heads": - self.model.config.num_attention_heads // shard_config.tensor_parallel_size, + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, param_replacement=[], sub_module_replacement=[ @@ -100,13 +99,43 @@ def postprocess(self): return self.model +# BertModel +class BertModelPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + +# BertForPreTraining +class BertForPretrainingPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + addon_module = { + BertLMPredictionHead: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="decoder", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + return module_policy + + +# BertForMaskedLM class BertForMaskedLMPolicy(BertPolicy): def __init__(self) -> None: super().__init__() - def module_policy(self, shard_config: ShardConfig = None): - module_policy = super().module_policy(shard_config) + def module_policy(self): + module_policy = super().module_policy() addon_module = { BertLMPredictionHead: ModulePolicyDescription(attribute_replacement={}, @@ -124,16 +153,41 @@ def module_policy(self, shard_config: ShardConfig = None): # BertLMHeadModel class BertLMHeadModelPolicy(BertPolicy): - @staticmethod - def argument_policy(config, world_size): - base_argument = BertPolicy.argument_policy(config, world_size) - argument = { - BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ - BertPolicy.unembedding, - ]), + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + addon_module = { + BertLMPredictionHead: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="decoder", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}) + ]) } - argument.update(base_argument) - return argument + module_policy.update(addon_module) + return module_policy + + +# BertForNextSentencePrediction +class BertForNextSentencePredictionPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + +# BertForSequenceClassification +class BertForSequenceClassificationPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + +# BertForMultipleChoice +class BertForMultipleChoicePolicy(BertPolicy): def __init__(self) -> None: super().__init__() diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 53999529d277..670a5775d8a9 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -18,10 +18,10 @@ class ShardConfig: will not calculate the loss and just return the output. gather_output (bool): Whether to gather the output of the model of the last layer """ - data_parallel_size: int tensor_parallel_size: int - - pipeline_parallel_size: int + # TODO: add support for tensor parallel + # pipeline_parallel_size: int + # data_parallel_size: int tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] inference_only: bool = True gather_output: bool = True diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 5c8584595c0c..b90e79059943 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -40,6 +40,7 @@ def shard(self) -> None: Shard the model according to the policy """ self.policy.set_model(self.model) + self.policy.set_shard_config(self.shard_config) self.preprocess() self.replace_model_class() self.replace_module() @@ -57,12 +58,12 @@ def reshape_embedding(self) -> None: self.model_config = self.model.config def preprocess(self) -> None: - self.model = self.policy.preprocess(self.shard_config) + self.model = self.policy.preprocess() def postprocess(self) -> None: self.model = self.policy.postprocess() - def replace_model_class(self,) -> None: + def replace_model_class(self) -> None: r""" Replace the model to policy defined model Mainly modify the forward and backward to fit distributed model @@ -83,14 +84,14 @@ def replace_model_class(self,) -> None: getattr(new_model_class, key), ) - def replace_module(self,) -> None: + def replace_module(self) -> None: r""" Replace the module according to the policy, and replace the module one by one Args: model (:class:`torch.nn.Module`): The model to shard """ - module_descriptions = self.policy.module_policy(self.shard_config) + module_descriptions = self.policy.module_policy() for module_description in module_descriptions.items(): origin_layer_cls = module_description[0] attr_replacement = module_description[1].attribute_replacement diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index 5313dfecb37e..954bdaa82454 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -25,11 +25,7 @@ class ShardFormer: org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') shard_config = ShardConfig( tensor_parallel_size=2, - data_parallel_size=1, - pipeline_parallel_size=1, tensor_parallel_mode='1d', - inference_only=True, - gather_output=True ) shard_former = ShardFormer(shard_config=shard_config) shard_former.init_distributed() diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 05d03343632f..0dd0fdeee8f8 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -7,7 +7,6 @@ AutoTokenizer, BertConfig, BertForMaskedLM, - BertForMultipleChoice, BertForNextSentencePrediction, BertForPreTraining, BertForSequenceClassification, @@ -36,12 +35,10 @@ def build_model(rank, world_size, model): org_model.to('cuda') # TODO: no need to transfer to cuda org_model_forshard.to('cuda') - shard_config = ShardConfig(tensor_parallel_size=2, - data_parallel_size=1, - pipeline_parallel_size=1, - tensor_parallel_mode='1d', - inference_only=True, - gather_output=True) + shard_config = ShardConfig( + tensor_parallel_size=2, + tensor_parallel_mode='1d', + ) shard_former = ShardFormer(shard_config=shard_config) shard_former.init_distributed() sharded_model = shard_former.shard_model(org_model_forshard).to('cuda') From b2c5dd0b9be0138487d776ba9d844f67f76bf29b Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 19 Jun 2023 13:53:17 +0800 Subject: [PATCH 26/49] [shardformer] adapted llama to the new API (#4036) --- colossalai/shardformer/policies/autopolicy.py | 134 ++++++------ colossalai/shardformer/policies/basepolicy.py | 5 + colossalai/shardformer/policies/llama.py | 195 +++++++++--------- colossalai/shardformer/shard/shard_config.py | 19 +- colossalai/shardformer/shard/sharder.py | 18 +- colossalai/shardformer/shard/shardformer.py | 6 +- .../test_model/test_shard_bert.py | 28 +-- .../test_model/test_shard_llama.py | 45 ++-- .../test_model/test_shard_t5.py | 3 +- 9 files changed, 245 insertions(+), 208 deletions(-) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 6239397b7cbe..e1b3a6a815a2 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -1,64 +1,76 @@ +import importlib +from dataclasses import dataclass + import torch.nn as nn from .basepolicy import Policy -def build_policies(): - r""" - Build the policies for the model - - Return: - The dict for the policies +@dataclass +class PolicyLocation: """ - auto_policy_dict = {} - - from transformers import BertModel - - from .bert import BertModelPolicy - auto_policy_dict[BertModel] = BertModelPolicy - - from transformers import BertForPreTraining - - from .bert import BertForPretrainingPolicy - auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy - - from transformers import BertLMHeadModel - - from .bert import BertLMHeadModelPolicy - auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy - - from transformers import BertForMaskedLM - - from .bert import BertForMaskedLMPolicy - auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy - - from transformers import BertForNextSentencePrediction + PolicyLocation describes the location of a policy class. - from .bert import BertForNextSentencePredictionPolicy - auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy - - from transformers import BertForSequenceClassification - - from .bert import BertForSequenceClassificationPolicy - auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy - from transformers.models.llama.modeling_llama import LlamaModel + Args: + file_name (str): The file name of the policy under colossalai.shardformer.policies + class_name (str): The class name of the policy class + """ + file_name: str + class_name: str + + +# we don't want to import all policies here +# as each policy file imports its own model zoo library +# we will allow the user to only import the policy file needed +_POLICY_LIST = { + # BERT + "transformers.models.bert.modeling_bert.BertModel": + PolicyLocation(file_name="bert", class_name="BertPolicy"), + "transformers.models.bert.modeling_bert.BertForPreTraining": + PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"), + "transformers.models.bert.modeling_bert.BertForMaskedLM": + PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"), + "transformers.models.bert.modeling_bert.BertLMHeadModel": + PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"), + "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": + PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), + "transformers.models.bert.modeling_bert.BertForSequenceClassification": + PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"), + "transformers.models.bert.modeling_bert.BertForMultipleChoice": + PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"), + + # LLaMA + "transformers.models.llama.modeling_llama.LlamaModel": + PolicyLocation(file_name="llama", class_name="LlamaPolicy"), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": + PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"), + "transformers.models.llama.modeling_llama.LlamaForSequenceClassification": + PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"), + + # T5 + + # GPT2 +} + + +def import_policy(policy_location: PolicyLocation) -> Policy: + """ + Dynamically import a Policy class based on the policy location. + """ + module_name = f"colossalai.shardformer.policies.{policy_location.file_name}" + module = importlib.import_module(module_name) + return getattr(module, policy_location.class_name) - # from .llama import LlamaPolicy - # auto_policy_dict[LlamaModel] = LlamaPolicy - # from transformers import LlamaForSequenceClassification - # from .llama import LlamaForSequenceClassificationPolicy - # auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy - # from transformers import LlamaForCausalLM - # from .llama import LlamaForCausalLMPolicy - # auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy - # from transformers import GPT2Model - # from .gpt2 import GPT2Policy - # auto_policy_dict[GPT2Model] = GPT2Policy - # from transformers import GPT2LMHeadModel - # from .gpt2 import GPT2LMHeadModelPolicy - # auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy - return auto_policy_dict +def _fullname(obj): + """ + Return the full name of an object, including the module name. + """ + klass = obj.__class__ + module = klass.__module__ + if module == 'builtins': + return klass.__qualname__ # avoid outputs like 'builtins.str' + return module + '.' + klass.__qualname__ def get_autopolicy(model: nn.Module) -> Policy: @@ -71,16 +83,14 @@ def get_autopolicy(model: nn.Module) -> Policy: Return: :class:`Policy`: The auto policy for the model """ - auto_policy_dict = build_policies() - policy = auto_policy_dict.get(model.__class__, None) - if policy is None: + full_name = _fullname(model) + policy_location = _POLICY_LIST.get(full_name, None) + + if policy_location is None: raise NotImplementedError( - f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}" + f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}" ) + else: + policy = import_policy(policy_location) + return policy() return policy() - - -# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining -# model = BertForPreTraining -# policy = get_autopolicy(model) -# print(policy) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index baae95980c14..e4f2e9432e10 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -75,6 +75,7 @@ class for the example. """ def __init__(self) -> None: + self.shard_config = None self.model = None self.shard_config = None @@ -101,6 +102,7 @@ def preprocess(self) -> nn.Module: r""" Perform some preprocessing of the model, like reshaping the embedding layer """ + pass @abstractmethod def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: @@ -135,6 +137,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ... } """ + pass @abstractmethod def new_model_class(self) -> Union[Type[nn.Module], None]: @@ -149,6 +152,7 @@ def new_model_class(self) -> Union[Type[nn.Module], None]: return BertModel_ ``` """ + pass @abstractmethod def postprocess(self) -> nn.Module: @@ -156,3 +160,4 @@ def postprocess(self) -> nn.Module: Perform some postprocessing of the model, like binding the weight of embedding layer with the classifier layer """ + pass diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index fac6765cdcb5..ae1b794fca12 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,122 +1,121 @@ -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple, Type +from typing import Dict, Union import torch.nn as nn +from transformers import LlamaForCausalLM, LlamaForSequenceClassification from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel -import colossalai.shardformer.layer.layers as col_nn +from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from .basepolicy import Argument, Col_Layer, Policy, Row_Layer +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class LlamaPolicy(Policy): - @staticmethod - def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: + def preprocess(self): + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: return { LlamaDecoderLayer: - Argument(attr_dict={ - "self_attn.hidden_size": config.hidden_size // world_size, - "self_attn.num_heads": config.num_attention_heads // world_size, - }, - param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]), + ModulePolicyDescription( + attribute_replacement={ + "self_attn.hidden_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, + ) + ], + ), LlamaModel: - Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings]) + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]) } - @staticmethod - def attn_layer() -> List: - return [ - Col_Layer( - suffix="self_attn.q_proj", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - suffix="self_attn.k_proj", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - suffix="self_attn.v_proj", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Row_Layer( - suffix="self_attn.o_proj", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Row, - ) - ] - - @staticmethod - def mlp_layer() -> List: - return [ - Col_Layer( - suffix="mlp.gate_proj", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, - ), - Col_Layer( - suffix="mlp.up_proj", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Row, - gather_output=True, - ), - Col_Layer( - suffix="mlp.down_proj", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, - ), - ] - - @staticmethod - def embeddings() -> List: - return [Col_Layer( - suffix="embed_tokens", - weight="weight", - replace_layer=col_nn.VocabParallelEmbedding1D, - )] - -from transformers import LlamaForCausalLM - - -class LlamaForCausalLMPolicy(LlamaPolicy): + def new_model_class(self): + return None - @staticmethod - def argument(config, world_size): - llamapolicy = LlamaPolicy.argument_policy(config, world_size) - argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])} - argument.update(llamapolicy) + def postprocess(self): + return self.model - @staticmethod - def lm_head() -> List: - return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)] +class LlamaForCausalLMPolicy(LlamaPolicy): -from transformers import LlamaForSequenceClassification + def module_policy(self): + policy = super().module_policy() + # add a new item for casual lm + new_item = { + LlamaForCausalLM: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + return policy class LlamaForSequenceClassificationPolicy(LlamaPolicy): - @staticmethod - def argument(config, world_size): - llamapolicy = LlamaPolicy.argument_policy(config, world_size) - argument = { + def module_policy(self): + policy = super().module_policy() + + # add a new item for sequence classification + new_item = { LlamaForSequenceClassification: - Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score]) + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="score", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) } - argument.update(llamapolicy) - - @staticmethod - def score() -> List: - return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)] + policy.update(new_item) + return policy diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 670a5775d8a9..7379a8208745 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import List, Literal + +from colossalai.cluster.dist_coordinator import DistCoordinator __all__ = ['ShardConfig'] @@ -19,9 +20,19 @@ class ShardConfig: gather_output (bool): Whether to gather the output of the model of the last layer """ tensor_parallel_size: int + # TODO: add support for tensor parallel # pipeline_parallel_size: int # data_parallel_size: int - tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] - inference_only: bool = True - gather_output: bool = True + # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] + # inference_only: bool = True + # gather_output: bool = True + + def __post_init__(self): + coordinator = DistCoordinator() + + # ensure the parallel size can match the world size + world_size = coordinator.world_size + self.data_parallel_size = world_size // self.tensor_parallel_size + assert world_size == self.data_parallel_size * self.tensor_parallel_size, \ + f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}" diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index b90e79059943..c948a7939d15 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,8 +1,6 @@ from typing import Any, Callable, Dict, List -import torch import torch.nn as nn -from transformers.pytorch_utils import Conv1D from colossalai.cluster.process_group_manager import ProcessGroupManager @@ -41,10 +39,10 @@ def shard(self) -> None: """ self.policy.set_model(self.model) self.policy.set_shard_config(self.shard_config) - self.preprocess() - self.replace_model_class() - self.replace_module() - self.postprocess() + self._preprocess() + self._replace_model_class() + self._replace_module() + self._postprocess() def reshape_embedding(self) -> None: r""" @@ -57,13 +55,13 @@ def reshape_embedding(self) -> None: self.model.resize_token_embeddings(new_vocab_size) self.model_config = self.model.config - def preprocess(self) -> None: + def _preprocess(self) -> None: self.model = self.policy.preprocess() - def postprocess(self) -> None: + def _postprocess(self) -> None: self.model = self.policy.postprocess() - def replace_model_class(self) -> None: + def _replace_model_class(self,) -> None: r""" Replace the model to policy defined model Mainly modify the forward and backward to fit distributed model @@ -84,7 +82,7 @@ def replace_model_class(self) -> None: getattr(new_model_class, key), ) - def replace_module(self) -> None: + def _replace_module(self,) -> None: r""" Replace the module according to the policy, and replace the module one by one diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index 954bdaa82454..1208a9d090fb 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -47,10 +47,12 @@ def init_distributed(self) -> ProcessGroupManager: """ Initialize the distributed process group according to the """ + # create process group manager and 1d process group + # TODO: may need to support other parallel mode when the config has such as field pg_manager = ProcessGroupManager() - if (self.shard_config.tensor_parallel_mode == '1d'): - pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size)) + pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size)) self.pg_manager = pg_manager + return pg_manager def shard_model(self, model: nn.Module, policy: Policy = None): diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 0dd0fdeee8f8..54fea0335e54 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -24,21 +24,18 @@ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") -def build_model(rank, world_size, model): - config = BertConfig.from_pretrained('bert-base-uncased') +def build_model(world_size, model_fn): + config = BertConfig() config.hidden_dropout_prob = 0 config.attention_probs_dropout_prob = 0 - org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config) + org_model = model_fn(config=config) org_model_forshard = copy.deepcopy(org_model) org_model.to('cuda') # TODO: no need to transfer to cuda org_model_forshard.to('cuda') - shard_config = ShardConfig( - tensor_parallel_size=2, - tensor_parallel_mode='1d', - ) + shard_config = ShardConfig(tensor_parallel_size=world_size,) shard_former = ShardFormer(shard_config=shard_config) shard_former.init_distributed() sharded_model = shard_former.shard_model(org_model_forshard).to('cuda') @@ -99,15 +96,22 @@ def check_bert(rank, world_size, port): disable_existing_loggers() colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') forward_list = [ - BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction, - BertForSequenceClassification + BertForMaskedLM, + BertForPreTraining, + BertLMHeadModel, + + # TODO: do not work yet + # BertModel, + # BertForSequenceClassification + # BertForNextSentencePrediction, ] backward_lsit = [BertForMaskedLM, BertLMHeadModel] - for model in forward_list: - org_model, sharded_model = build_model(rank, world_size, model) + for model_fn in forward_list: + org_model, sharded_model = build_model(model_fn) check_forward(org_model, sharded_model) - if model in backward_lsit: + + if model_fn in backward_lsit: check_backward(org_model, sharded_model) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 689898bbbad2..a3c7647fafc6 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -4,31 +4,28 @@ import pytest import torch -from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaTokenizerFast +from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.shard import ShardConfig, shard_model +from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.testing import rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=4, mode='1d')),) tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") -def build_model(rank, world_size): - cfg = LlamaConfig(num_hidden_layers=16) - org_model = LlamaForCausalLM(cfg) +def build_model(world_size, model_fn): + # create new model + config = LlamaConfig(num_hidden_layers=8) + org_model = model_fn(config).cuda() - shardconfig = ShardConfig( - rank=rank, - world_size=world_size, - gather_output=True, - ) - org_model = org_model.to('cuda') - - org_model_forshard = copy.deepcopy(org_model) - sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda') + # shard model + shard_config = ShardConfig(tensor_parallel_size=world_size) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + sharded_model = shard_former.shard_model(model_copy) return org_model, sharded_model @@ -38,6 +35,7 @@ def check_forward(org_model, sharded_model): inputs = tokenizer(input, return_tensors='pt').to('cuda') del inputs["token_type_ids"] del inputs["attention_mask"] + #orgin model org_model.eval() org_out = org_model(**inputs) @@ -87,11 +85,20 @@ def check_backward(org_model, sharded_model): def check_llama(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + model_list = [ + LlamaForCausalLM, + + # TODO: do not work yet + # LlamaModel, + # LlamaForSequenceClassification + ] - org_model, sharded_model = build_model(rank, world_size) - check_forward(org_model, sharded_model) - check_backward(org_model, sharded_model) + for model_fn in model_list: + org_model, sharded_model = build_model(world_size, model_fn) + check_forward(org_model, sharded_model) + check_backward(org_model, sharded_model) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index ca44f0b00a74..9b1c2678f39b 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -8,7 +8,7 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.shard import ShardConfig, shard_model +from colossalai.shardformer.shard import ShardConfig, ShardFormer from colossalai.testing import rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' @@ -90,6 +90,7 @@ def check_t5(rank, world_size, port): @pytest.mark.dist +@pytest.mark.skip @rerun_if_address_is_in_use() def test_t5(): spawn(check_t5, 2) From 8219d960922fb2a204f41f65071664b2e8f1b52c Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 19 Jun 2023 17:57:37 +0800 Subject: [PATCH 27/49] [shardformer] supported T5 and its variants (#4045) --- colossalai/shardformer/README.md | 5 +- colossalai/shardformer/layer/layers.py | 26 +- colossalai/shardformer/policies/autopolicy.py | 6 + colossalai/shardformer/policies/basepolicy.py | 1 + colossalai/shardformer/policies/t5.py | 258 +++++++++--------- colossalai/shardformer/shard/sharder.py | 11 +- colossalai/testing/__init__.py | 3 +- colossalai/testing/comparison.py | 51 +++- .../test_model/test_shard_llama.py | 82 +++--- .../test_model/test_shard_t5.py | 94 ++++--- 10 files changed, 316 insertions(+), 221 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index dc2946ec937f..fee4cce7a28a 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -81,8 +81,8 @@ We will follow this roadmap to develop Shardformer: - [ ] Hugging Face - [ ] NLP - [x] BERT - - [ ] T5 - - [ ] LlaMa + - [x] T5 + - [x] LlaMa - [ ] GPT2 - [ ] BLOOM - [ ] RoBERTa @@ -90,7 +90,6 @@ We will follow this roadmap to develop Shardformer: - [ ] ERNIE - [ ] GPT Neo - [ ] GPT-J - - [ ] CV - [ ] CV - [ ] ViT - [ ] BEiT diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py index ad6e1896aa5e..5dbe28956d27 100644 --- a/colossalai/shardformer/layer/layers.py +++ b/colossalai/shardformer/layer/layers.py @@ -469,13 +469,14 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + gather_output: bool = True, weight_initializer: Callable = init.normal_(), *args, **kwargs): super().__init__() self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim + self.embedding_dim = embedding_dim self.process_group = process_group self.num_partitions = dist.get_world_size(process_group) self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) @@ -499,7 +500,9 @@ def __init__(self, @staticmethod def from_native_module(module: nn.Embedding, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D": + process_group: Union[ProcessGroup, List[ProcessGroup]] = None, + *args, + **kwargs) -> "Embedding1D": r""" Build a 1D parallelized Embedding from a native nn.Embedding module. """ @@ -527,7 +530,9 @@ def from_native_module(module: nn.Embedding, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) + sparse=sparse, + *args, + **kwargs) # copy the weight with torch.no_grad(): @@ -537,7 +542,7 @@ def from_native_module(module: nn.Embedding, return embedding def reset_parameters(self, weight_initializer) -> None: - fan_in, fan_out = self.num_embeddings, self.embed_dim + fan_in, fan_out = self.num_embeddings, self.embedding_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() @@ -548,9 +553,12 @@ def _fill_padding_idx_with_zero(self) -> None: def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - return output + if self.gather_output: + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + return output + else: + return output_parallel class VocabParallelEmbedding1D(ParallelLayer): @@ -595,7 +603,7 @@ def __init__(self, **kwargs): super().__init__() self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim + self.embedding_dim = embedding_dim self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs @@ -610,7 +618,7 @@ def __init__(self, self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype)) + torch.empty((self.num_embeddings_per_partition, self.embedding_dim), device=device, dtype=dtype)) # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -662,7 +670,7 @@ def _set_tensor_parallel_attributes(self): def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim + fan_in, fan_out = self.num_embeddings, self.embedding_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index e1b3a6a815a2..6ce0b8fb3a3d 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -48,6 +48,12 @@ class PolicyLocation: PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"), # T5 + "transformers.models.t5.modeling_t5.T5Model": + PolicyLocation(file_name="t5", class_name="T5ModelPolicy"), + "transformers.models.t5.modeling_t5.T5ForConditionalGeneration": + PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"), + "transformers.models.t5.modeling_t5.T5EncoderModel": + PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), # GPT2 } diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index e4f2e9432e10..175a914a84f9 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -27,6 +27,7 @@ class SubModuleReplacementDescription: suffix: str target_module: ParallelModule kwargs: Dict[str, Any] = None + ignore_if_not_exist: bool = False @dataclass diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 7b013a37845a..9c8ee59b4178 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,159 +1,173 @@ -from typing import Dict - +import torch import torch.nn as nn -from torch.nn import Embedding +from transformers import T5ForConditionalGeneration from transformers.models.t5.modeling_t5 import ( T5Attention, - T5Block, T5DenseActDense, T5DenseGatedActDense, T5LayerCrossAttention, T5LayerFF, T5LayerSelfAttention, - T5Model, T5Stack, ) -import colossalai.shardformer.layer.layers as col_nn +from colossalai.shardformer.layer.dropout import Dropout1D +from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row -from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] class T5ModelPolicy(Policy): - @staticmethod - def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: - print('config heads', config.num_heads) + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): return { T5Stack: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]), - T5Block: - Argument(attr_dict={}, param_funcs=[]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ) + ]), T5LayerSelfAttention: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ), + ]), T5LayerCrossAttention: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ) + ]), T5Attention: - Argument(attr_dict={ - "d_model": config.d_model // world_size, - "n_heads": config.num_heads // world_size, - "inner_dim": config.num_heads * config.d_kv // world_size, + ModulePolicyDescription(attribute_replacement={ + "d_model": + self.model.config.d_model // self.shard_config.tensor_parallel_size, + "n_heads": + self.model.config.num_heads // self.shard_config.tensor_parallel_size, + "inner_dim": + self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size }, - param_funcs=[T5ModelPolicy.attn_layer]), + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="o", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription(suffix="relative_attention_bias", + target_module=Embedding1D, + kwargs=dict(gather_output=False), + ignore_if_not_exist=True) + ]), T5LayerFF: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ), + ]), T5DenseGatedActDense: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription(suffix="wo", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ) + ]), T5DenseActDense: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ) + ]) } - @staticmethod - def dense_gated_layer(): - return [ - Col_Layer( - suffix="wi_0", - weight="weight", - replace_layer=col_nn.Linear1D_Col, - ), - Row_Layer( - suffix="wi_1", - weight="weight", - replace_layer=col_nn.Linear1D_Row, - ), - Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True) - ] - - @staticmethod - def dense_act_layer(): - return [ - Col_Layer( - suffix="wi", - weight="weight", - replace_layer=col_nn.Linear1D_Col, - ), - Row_Layer( - suffix="wo", - weight="weight", - replace_layer=col_nn.Linear1D_Row, - ) - ] - - @staticmethod - def attn_layer(): - return [ - Col_Layer( - suffix="q", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - suffix="k", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - suffix="v", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Row_Layer( - suffix="o", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Row, - ), - ] - - @staticmethod - def dropout(): - return [Dropout_Layer( - suffix="dropout", - p="p", - replace_layer=col_nn.Dropout1D, - )] - - @staticmethod - def embedding(): - return [ - Embedding_Layer( - suffix="block[0].layer[0].SelfAttention.relative_attention_bias", - weight="weight", - replace_layer=col_nn.Embedding1D, - gather_output=False, - ) - ] + def new_model_class(self): + return None - -from transformers import T5ForConditionalGeneration + def postprocess(self): + return self.model class T5ForConditionalGenerationPolicy(T5ModelPolicy): - @staticmethod - def argument_policy(config, world_size): - base_argument = T5ModelPolicy.argument_policy(config, world_size) - argument = { - T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head]) + def module_policy(self): + policy = super().module_policy() + + new_item = { + T5ForConditionalGeneration: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) } - argument.update(base_argument) - return argument - - @staticmethod - def lm_head(): - return [Col_Layer( - suffix="lm_head", - weight="weight", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, - )] - -from transformers import T5EncoderModel + policy.update(new_item) + return policy -class T5EncoderModelPolicy(T5ModelPolicy): +class T5EncoderPolicy(T5ModelPolicy): pass diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index c948a7939d15..f6ade26b758a 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -175,7 +175,16 @@ def _replace_sub_module( assert target_module is not None, 'target_module should not be None' # TODO: support different parallel mode - native_sub_module = getattr_(org_layer, suffix) + native_sub_module = getattr_(org_layer, suffix, ignore=True) + + assert not isinstance(native_sub_module, target_module), \ + f"The module with suffix {suffix} has been replaced, please check the policy" + + # if it is None and we are allowed to ignore this module + # just skip + if description.ignore_if_not_exist and native_sub_module is None: + continue + replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'], **kwargs) diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index 9d0475ed064c..0db33361c6a0 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -3,6 +3,7 @@ assert_close_loose, assert_equal, assert_equal_in_group, + assert_hf_output_close, assert_not_equal, check_state_dict_equal, ) @@ -20,5 +21,5 @@ __all__ = [ 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn', - 'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal' + 'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal', 'assert_hf_output_close' ] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index faf61638d8bb..aeecee7f11f5 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -1,4 +1,4 @@ -from typing import OrderedDict +from typing import Any, List, OrderedDict import torch import torch.distributed as dist @@ -52,3 +52,52 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool assert torch.equal(v, d2[k]) else: assert v == d2[k] + + +def assert_hf_output_close(out1: Any, + out2: Any, + ignore_keys: List[str] = None, + track_name: str = "", + atol=1e-5, + rtol=1e-5): + """ + Check if two outputs from huggingface are equal. + + Args: + out1 (Any): the first output + out2 (Any): the second output + ignore_keys (List[str]): the keys to ignore when comparing two dicts + track_name (str): the name of the value compared, used to track the path + """ + if isinstance(out1, dict) and isinstance(out2, dict): + # if two values are dict + # we recursively check the keys + assert set(out1.keys()) == set(out2.keys()) + for k in out1.keys(): + if ignore_keys is not None and k in ignore_keys: + continue + assert_hf_output_close(out1[k], + out2[k], + track_name=f"{track_name}.{k}", + ignore_keys=ignore_keys, + atol=atol, + rtol=rtol) + elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)): + # if two values are list + # we recursively check the elements + assert len(out1) == len(out2) + for i in range(len(out1)): + assert_hf_output_close(out1[i], + out2[i], + track_name=f"{track_name}.{i}", + ignore_keys=ignore_keys, + atol=atol, + rtol=rtol) + elif isinstance(out1, Tensor) and isinstance(out2, Tensor): + if out1.shape != out2.shape: + raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}") + assert torch.allclose( + out1, out2, atol=atol, rtol=rtol + ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}" + else: + assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}" diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a3c7647fafc6..b15f81aba52e 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -9,7 +9,7 @@ import colossalai from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") @@ -17,7 +17,11 @@ def build_model(world_size, model_fn): # create new model - config = LlamaConfig(num_hidden_layers=8) + config = LlamaConfig(num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128) org_model = model_fn(config).cuda() # shard model @@ -30,49 +34,47 @@ def build_model(world_size, model_fn): return org_model, sharded_model -def check_forward(org_model, sharded_model): - input = 'Hello, my dog is cute' - inputs = tokenizer(input, return_tensors='pt').to('cuda') - del inputs["token_type_ids"] - del inputs["attention_mask"] - - #orgin model - org_model.eval() - org_out = org_model(**inputs) - - #shard model - sharded_model.eval() - shard_out = sharded_model(**inputs) - - assert torch.allclose( - org_out[0], shard_out[0], - atol=1e-4), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" - - -def check_backward(org_model, sharded_model): +def check_forward_backward(org_model, sharded_model): # prepare input input = 'Hello, my dog is cute' tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') del tokenized_input["token_type_ids"] del tokenized_input["attention_mask"] - labels = tokenized_input['input_ids'].clone() - labels[labels == tokenizer.pad_token_id] = -100 - tokenized_input['labels'] = labels - #orgin model + # switch to train mode org_model.train() - org_out = org_model(**tokenized_input) - org_loss = org_out.loss - org_loss.backward() - org_grad = org_model.model.layers[0].self_attn.q_proj.weight.grad - - torch.cuda.empty_cache() - #shard model sharded_model.train() - shard_out = sharded_model(**tokenized_input) - shard_loss = shard_out.loss + + if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)): + org_output = org_model(**tokenized_input) + org_loss = org_output.last_hidden_state.mean() + shard_output = sharded_model(**tokenized_input) + shard_loss = shard_output.last_hidden_state.mean() + elif isinstance(org_model, LlamaForCausalLM): + labels = tokenized_input['input_ids'].clone() + labels[labels == tokenizer.pad_token_id] = -100 + tokenized_input['labels'] = labels + org_output = org_model(**tokenized_input) + org_loss = org_output.loss + shard_output = sharded_model(**tokenized_input) + shard_loss = shard_output.loss + + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4) + + # run backward + org_loss.backward() shard_loss.backward() - shard_grad = sharded_model.model.layers[0].self_attn.q_proj.weight.grad + + # check grad + if isinstance(org_model, LlamaModel): + llama_model = org_model + shard_llama_model = sharded_model + else: + llama_model = org_model.model + shard_llama_model = sharded_model.model + + org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad + shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=0) @@ -88,23 +90,23 @@ def check_llama(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model_list = [ - LlamaForCausalLM, + LlamaModel, + # LlamaForCausalLM, # TODO: do not work yet - # LlamaModel, # LlamaForSequenceClassification ] for model_fn in model_list: org_model, sharded_model = build_model(world_size, model_fn) - check_forward(org_model, sharded_model) - check_backward(org_model, sharded_model) + check_forward_backward(org_model, sharded_model) torch.cuda.empty_cache() @pytest.mark.dist @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_llama(): spawn(check_llama, 4) diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 9b1c2678f39b..254649409c59 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,71 +1,72 @@ import copy import os -import random import pytest import torch -from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer +from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast import colossalai from colossalai.logging import disable_existing_loggers from colossalai.shardformer.shard import ShardConfig, ShardFormer -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) tokenizer = T5Tokenizer.from_pretrained("t5-small") -def build_model(rank, world_size): - config = T5Config.from_pretrained("t5-small") +def build_model(world_size, model_fn): + config = T5Config(decoder_start_token_id=0) config.dropout_rate = 0 - org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda') + org_model = model_fn(config=config).to('cuda') + shard_config = ShardConfig(tensor_parallel_size=world_size) - shardconfig = ShardConfig( - rank=rank, - world_size=world_size, - gather_output=True, - ) - - org_model_for_shard = copy.deepcopy(org_model) - - sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda') + # shard model + shard_config = ShardConfig(tensor_parallel_size=world_size) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + sharded_model = shard_former.shard_model(model_copy) return org_model, sharded_model -def check_forward(org_model, sharded_model): - - input_ids = tokenizer("translate English to German: The house is wonderful.", - return_tensors="pt").input_ids.to('cuda') - #orgin model - org_model.eval() - org_output = org_model.generate(input_ids) - - #shard model - sharded_model.eval() - shard_output = sharded_model.generate(input_ids) - assert torch.allclose( - org_output[0], shard_output[0], - atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" - - -def check_backward(org_model, sharded_model): +def check_forward_backward(org_model, sharded_model): # prepare input input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids.to('cuda') labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda') - #orgin model + # switch to train mode org_model.train() - org_loss = org_model(input_ids=input_ids, labels=labels).loss - org_loss.backward() - org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad - - #shard model sharded_model.train() - shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss + + if isinstance(org_model, T5ForConditionalGeneration): + org_output = org_model(input_ids=input_ids, labels=labels) + org_loss = org_output.loss + shard_output = sharded_model(input_ids=input_ids, labels=labels) + shard_loss = shard_output.loss + elif isinstance(org_model, T5Model): + decoder_input_ids = org_model._shift_right(input_ids) + org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + org_loss = org_output.last_hidden_state.mean() + shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + shard_loss = shard_output.last_hidden_state.mean() + elif isinstance(org_model, T5EncoderModel): + org_output = org_model(input_ids=input_ids) + org_loss = org_output.last_hidden_state.mean() + shard_output = sharded_model(input_ids=input_ids) + shard_loss = shard_output.last_hidden_state.mean() + + # key is sharded, so we ignore + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + + # do backward + org_loss.backward() shard_loss.backward() + + # check grad equality + org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] @@ -82,16 +83,21 @@ def check_t5(rank, world_size, port): disable_existing_loggers() colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - org_model, sharded_model = build_model(rank, world_size) - check_forward(org_model, sharded_model) - check_backward(org_model, sharded_model) + model_fn_list = [ + T5Model, + T5ForConditionalGeneration, + T5EncoderModel, + ] - torch.cuda.empty_cache() + for model_fn in model_fn_list: + org_model, sharded_model = build_model(world_size, model_fn) + check_forward_backward(org_model, sharded_model) + torch.cuda.empty_cache() @pytest.mark.dist -@pytest.mark.skip @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_t5(): spawn(check_t5, 2) From 0113097cf66fccbc1014f6ac0565436fed3fe6e4 Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Tue, 20 Jun 2023 11:45:16 +0800 Subject: [PATCH 28/49] [shardformer] add gpt2 test and layer class refactor (#4041) * add gpt2 test and layer class refactor * add dropout in gpt2 policy --- colossalai/shardformer/layer/__init__.py | 17 + colossalai/shardformer/layer/dropout.py | 2 +- colossalai/shardformer/layer/embedding1d.py | 149 ++++ colossalai/shardformer/layer/layernorm1d.py | 73 ++ colossalai/shardformer/layer/layers.py | 722 ------------------ colossalai/shardformer/layer/linear1d.py | 346 +++++++++ colossalai/shardformer/layer/linearconv1d.py | 377 +++++++++ .../shardformer/layer/parallelmodule.py | 35 + .../layer/vocabparallelembedding1d.py | 170 +++++ colossalai/shardformer/policies/autopolicy.py | 3 +- colossalai/shardformer/policies/bert.py | 37 +- colossalai/shardformer/policies/gpt2.py | 189 ++--- .../test_model/test_shard_bert.py | 2 +- .../test_model/test_shard_gpt2.py | 118 +++ 14 files changed, 1400 insertions(+), 840 deletions(-) create mode 100644 colossalai/shardformer/layer/embedding1d.py create mode 100644 colossalai/shardformer/layer/layernorm1d.py delete mode 100644 colossalai/shardformer/layer/layers.py create mode 100644 colossalai/shardformer/layer/linear1d.py create mode 100644 colossalai/shardformer/layer/linearconv1d.py create mode 100644 colossalai/shardformer/layer/parallelmodule.py create mode 100644 colossalai/shardformer/layer/vocabparallelembedding1d.py create mode 100644 tests/test_shardformer/test_model/test_shard_gpt2.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index e69de29bb2d1..66d86913bb2b 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -0,0 +1,17 @@ +from .dropout import Dropout1D +from .embedding1d import Embedding1D +from .layernorm1d import LayerNorm1D +from .linear1d import Linear1D_Col, Linear1D_Row +from .linearconv1d import LinearConv1D_Col, LinearConv1D_Row +from .vocabparallelembedding1d import VocabParallelEmbedding1D + +__all__ = [ + "Embedding1D", + "VocabParallelEmbedding1D", + "Linear1D_Col", + "Linear1D_Row", + "LinearConv1D_Col", + "LinearConv1D_Row", + "LayerNorm1D", + "Dropout1D", +] diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index ec08d072f338..08dfb8afd7fb 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch.distributed import ProcessGroup -from .layers import ParallelModule +from .parallelmodule import ParallelModule from .utils import create_randomizer_with_offset diff --git a/colossalai/shardformer/layer/embedding1d.py b/colossalai/shardformer/layer/embedding1d.py new file mode 100644 index 000000000000..1108d5d6a936 --- /dev/null +++ b/colossalai/shardformer/layer/embedding1d.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Callable, List, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_colwise +from colossalai.utils.cuda import get_current_device + +from ._operation import gather_forward_split_backward +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class Embedding1D(ParallelModule): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.process_group = process_group + self.num_partitions = dist.get_world_size(process_group) + self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + # self.gather_output = gather_output + + if device is None: + device = get_current_device() + + self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D": + r""" + Build a 1D parallelized Embedding from a native nn.Embedding module. + """ + # get the attributes + num_embedding = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + max_norm = module.max_norm + norm_type = module.norm_type + scale_grad_by_freq = module.scale_grad_by_freq + sparse = module.sparse + dtype = module.weight.dtype + device = module.weight.device + + # sparse is not support yet + if sparse: + raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") + + embedding = Embedding1D(num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + + # copy the weight + with torch.no_grad(): + sharded_weight = shard_colwise(module.weight.data, process_group) + embedding.weight.copy_(sharded_weight) + + return embedding + + def reset_parameters(self, weight_initializer) -> None: + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + + return output diff --git a/colossalai/shardformer/layer/layernorm1d.py b/colossalai/shardformer/layer/layernorm1d.py new file mode 100644 index 000000000000..78bd64cfb504 --- /dev/null +++ b/colossalai/shardformer/layer/layernorm1d.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from collections import OrderedDict + +from colossalai.context import ParallelMode, seed +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.kernel import LayerNorm +from colossalai.nn import init as init +from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule +from colossalai.utils.checkpointing import broadcast_state_dict + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class LayerNorm1D(ColossalaiModule): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps).to(dtype) + super().__init__(norm) + + def _load_from_state_dict(self, state_dict, prefix, *args): + local_state = OrderedDict() + weight_key = prefix + 'weight' + bias_key = prefix + 'bias' + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + # weight + weight = state_dict.pop(weight_key, None) + if weight is not None: + local_state[weight_key] = weight + # bias + bias = state_dict.pop(bias_key, None) + if bias is not None: + local_state[bias_key] = bias + + local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) + super()._load_from_state_dict(local_state, prefix, *args) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if gpc.get_local_rank(ParallelMode.TENSOR) == 0: + super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py deleted file mode 100644 index 5dbe28956d27..000000000000 --- a/colossalai/shardformer/layer/layers.py +++ /dev/null @@ -1,722 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import math -from abc import ABC, abstractmethod -from collections import OrderedDict -from typing import Callable, List, Tuple, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter - -from colossalai.communication import broadcast -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.kernel import LayerNorm -from colossalai.nn import init as init -from colossalai.nn.layer.base_layer import ParallelLayer -from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule -from colossalai.nn.layer.parallel_1d._utils import get_parallel_input, reduce_grad, set_parallel_input -from colossalai.nn.layer.utils import divide, set_tensor_parallel_attribute_by_partition -from colossalai.nn.layer.vanilla import VanillaLayerNorm, VanillaPatchEmbedding -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise -from colossalai.utils.checkpointing import ( - broadcast_state_dict, - gather_tensor_parallel_state_dict, - partition_tensor_parallel_state_dict, -) -from colossalai.utils.cuda import get_current_device - -from ._operation import ( - gather_forward_split_backward, - linear_with_async_comm, - reduce_input, - split_forward_gather_backward, -) -from .utils import create_randomizer_with_offset - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class ParallelModule(nn.Module, ABC): - - @abstractmethod - def from_native_module(module: nn.Module, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": - """ - Convert a native PyTorch module to a parallelized module. - - Args: - module (nn.Module): the module to be converted. - process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. - If this is a list, the process group at the ith index of the list will correspond to the process group - in the ith axis of the device mesh. Defaults to None, which means the global process group. - """ - pass - - -class Linear1D_Col(ParallelModule): - r"""Linear layer with column parallelism. - - The linear layer is defined as :math:`Y = XA + b`. A is parallelized along - its second dimension as :math:`A = [A_1, ..., A_p]`. - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (`torch.dtype`): The dtype of parameters, defaults to None. - device (`torch.device`): The device of parameters, defaults to None. - process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. - gather_output (bool, optional): If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is :math:`Y_i = XA_i`, defaults to False - skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (`typing.Callable`): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (`typing.Callable`): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = False, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): - super().__init__() - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.gather_output = gather_output - self.skip_bias_add = skip_bias_add - self.device = device - self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - self.out_features_per_partition = divide(out_features, self.num_partitions) - - # Parameters. - # Initialize weight. - if device is None: - device = get_current_device() - factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) - - if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) - else: - self.bias = None - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) - - @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: - r""" - Convert a native PyTorch linear layer to a parallelized linear layer. - """ - # get the attributes - in_features = module.in_features - out_features = module.out_features - bias = module.bias is not None - device = module.weight.device - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' - process_group = process_group[0] - - linear_1d = Linear1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) - - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on row is equal to shard on column - sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - if bias: - sharded_bias = shard_colwise(module.bias.data, process_group) - linear_1d.bias.copy_(sharded_bias) - - return linear_1d - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - - def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - input_parallel = input_ - # Matrix multiply. - bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) - - if self.gather_output: - # All-gather across the partitions. - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - else: - output = output_parallel - - if self.skip_bias_add: - return output, self.bias - else: - return output - - -class Linear1D_Row(ParallelModule): - r""" Linear layer with row parallelism - - Args: - in_features (int): size of each input sample. - out_features (int): size of each output sample. - bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. - dtype (`torch.dtype`): The dtype of parameters, defaults to None. - parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. - skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, - which is preserved for kernel fusion, defaults to False - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - - More details about ``initializer`` please refer to - `init `_. - """ - - def __init__(self, - in_features: int, - out_features: int, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - parallel_input: bool = True, - skip_bias_add: bool = False, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - stream_chunk_num: int = 1): - super().__init__() - - self.stream_chunk_num = stream_chunk_num - - # Keep input parameters - self.in_features = in_features - self.out_features = out_features - self.parallel_input = parallel_input - self.skip_bias_add = skip_bias_add - self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) - - if skip_bias_add and not bias: - raise ValueError('cannot skip bias addition if bias is None') - - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, self.num_partitions) - - # Parameters. - # Initialize weight. - if device is None: - device = get_current_device() - - factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) - - if self.stream_chunk_num > 1: - # TODO() work for inference only - self.chunk_weight() - if bias: - self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) - else: - self.bias = None - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) - - @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: - r""" - Convert a native PyTorch linear layer to a parallelized linear layer. - """ - # get the attributes - in_features = module.in_features - out_features = module.out_features - bias = module.bias is not None - device = module.weight.device - - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' - process_group = process_group[0] - - linear_1d = Linear1D_Row(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) - - # TODO: copy the sharded weights - with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on col is equal to shard on row - sharded_weight = shard_colwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight) - - if bias: - linear_1d.bias.copy_(module.bias.data) - - return linear_1d - - def chunk_weight(self): - self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) - - def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - if self.process_group is None: - src_rank = 0 - else: - src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) - - origin_device = self.bias.device - self.bias = self.bias.cuda() - dist.broadcast(self.bias, src=src_rank, group=self.process_group) - self.bias = self.bias.to(origin_device) - - def forward(self, input_: Tensor) -> Tensor: - # Set up backprop all-reduce. - if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1]) - input_ = input_ - else: - assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( - input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) - input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) - - if self.stream_chunk_num > 1: - if self.training: - raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") - with torch.no_grad(): - output_parallel_list = [None for i in range(self.stream_chunk_num)] - handle_list = [] - for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - handle = torch.distributed.all_reduce(output_parallel_list[i], - group=self.process_group, - async_op=True) - handle_list.append(handle) - # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) - for handle in handle_list: - handle.wait() - output = torch.cat(output_parallel_list, dim=-1) - else: - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, self.process_group) - - if not self.skip_bias_add: - if self.bias is not None: - output = output + self.bias - return output - else: - return output, self.bias - - -class LayerNorm1D(ColossalaiModule): - r""" - Layer Normalization for colossalai - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - _fast_ln_supported_sizes = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536 - ] - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: - norm = Fast_LN(normalized_shape, eps=eps).to(dtype) - else: - norm = None - try: - from apex.normalization import FusedLayerNorm - norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) - except ImportError: - norm = LayerNorm(normalized_shape, eps=eps).to(dtype) - super().__init__(norm) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) - - -class Embedding1D(ParallelModule): - r"""Embedding for 1D parallelism. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about ``initializer`` please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = True, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.process_group = process_group - self.num_partitions = dist.get_world_size(process_group) - self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) - - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - self.gather_output = gather_output - - if device is None: - device = get_current_device() - - self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) - - @staticmethod - def from_native_module(module: nn.Embedding, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None, - *args, - **kwargs) -> "Embedding1D": - r""" - Build a 1D parallelized Embedding from a native nn.Embedding module. - """ - # get the attributes - num_embedding = module.num_embeddings - embedding_dim = module.embedding_dim - padding_idx = module.padding_idx - max_norm = module.max_norm - norm_type = module.norm_type - scale_grad_by_freq = module.scale_grad_by_freq - sparse = module.sparse - dtype = module.weight.dtype - device = module.weight.device - - # sparse is not support yet - if sparse: - raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") - - embedding = Embedding1D(num_embeddings=num_embedding, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - process_group=process_group, - dtype=dtype, - device=device, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - *args, - **kwargs) - - # copy the weight - with torch.no_grad(): - sharded_weight = shard_colwise(module.weight.data, process_group) - embedding.weight.copy_(sharded_weight) - - return embedding - - def reset_parameters(self, weight_initializer) -> None: - fan_in, fan_out = self.num_embeddings, self.embedding_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input_: Tensor) -> Tensor: - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - if self.gather_output: - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - return output - else: - return output_parallel - - -class VocabParallelEmbedding1D(ParallelLayer): - r"""Embedding parallelized in the vocabulary dimension. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about initializer please refer to - `init `_. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - self.process_group = process_group - - tensor_parallel_size = dist.get_world_size(group=process_group) - tensor_parallel_rank = dist.get_rank(group=process_group) - - self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) - self.num_embeddings = self.num_embeddings_per_partition - self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition - self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - - self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embedding_dim), device=device, dtype=dtype)) - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) - # self.reset_parameters(weight_initializer) - # self._set_tensor_parallel_attributes() - # set_parallel_input(False) - # env.vocab_parallel = True - - @staticmethod - def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: - r""" - Convert a native pytorch embedding module to a parallel module. - """ - # get the origin attributes - num_embeddings = module.num_embeddings - embedding_dim = module.embedding_dim - padding_idx = module.padding_idx - device = module.weight.device - - # ensure only one process group is used - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' - process_group = process_group[0] - - # create the parallel module - vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - device=device, - process_group=process_group, - *args, - **kwargs) - with torch.no_grad(): - # shard and slice the weight along the vocabulary(num_embeddings) dimension - # the shape of the weight is (num_embeddings, embedding_dim) - shard_weight = shard_rowwise(module.weight.data, process_group) - vocab_embedding_1d.weight.data.copy_(shard_weight) - - return vocab_embedding_1d - - def _set_tensor_parallel_attributes(self): - set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size) - - def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embedding_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None and \ - self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: - with torch.no_grad(): - self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - - def _load_from_global_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - - local_state = partition_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}) - super()._load_from_global_state_dict(local_state, prefix, *args) - - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - - def forward(self, input_: Tensor) -> Tensor: - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - - output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, - **self.embed_kwargs) - - # Mask the output embedding. - output_parallel[input_mask, :] = 0. - # Reduce across all the model parallel GPUs. - output = reduce_input(output_parallel, self.process_group) - return output diff --git a/colossalai/shardformer/layer/linear1d.py b/colossalai/shardformer/layer/linear1d.py new file mode 100644 index 000000000000..d59d32df824e --- /dev/null +++ b/colossalai/shardformer/layer/linear1d.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise +from colossalai.utils.cuda import get_current_device + +from ._operation import ( + gather_forward_split_backward, + linear_with_async_comm, + reduce_input, + split_forward_gather_backward, +) +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class Linear1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + self.out_features_per_partition = divide(out_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = Linear1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on row is equal to shard on column + sharded_weight = shard_rowwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + if bias: + sharded_bias = shard_colwise(module.bias.data, process_group) + linear_1d.bias.copy_(sharded_bias) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class Linear1D_Row(ParallelModule): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.in_features + out_features = module.out_features + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = Linear1D_Row(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on col is equal to shard on row + sharded_weight = shard_colwise(module.weight.data, process_group) + linear_1d.weight.data.copy_(sharded_weight) + + if bias: + linear_1d.bias.copy_(module.bias.data) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + self.bias = self.bias.cuda() + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias = self.bias.to(origin_device) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=self.process_group, + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias diff --git a/colossalai/shardformer/layer/linearconv1d.py b/colossalai/shardformer/layer/linearconv1d.py new file mode 100644 index 000000000000..4a5cb0707900 --- /dev/null +++ b/colossalai/shardformer/layer/linearconv1d.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +from typing import Callable, List, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.nn import init as init +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise +from colossalai.utils.cuda import get_current_device + +from ._operation import ( + gather_forward_split_backward, + linear_with_async_comm, + reduce_input, + split_forward_gather_backward, +) +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class LinearConv1D_Col(ParallelModule): + r"""Linear layer with column parallelism. + + The linear layer is defined as :math:`Y = XA + b`. A is parallelized along + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface. + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + device (`torch.device`): The device of parameters, defaults to None. + process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. + gather_output (bool, optional): If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is :math:`Y_i = XA_i`, defaults to False + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (`typing.Callable`): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (`typing.Callable`): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = False, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + super().__init__() + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.skip_bias_add = skip_bias_add + self.device = device + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + self.out_features_per_partition = divide(out_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + if bias: + self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, + *args, **kwargs) -> ParallelModule: + r""" + Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. + """ + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = LinearConv1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on row is equal to shard on column + + # first rearange the order of weight and bias + world_size = dist.get_world_size(group=process_group) + order = torch.arange(world_size * n_cast) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=1) + rearanged_weight_chunks = [weight_chunks[i] for i in new_order] + rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1) + sharded_weight = shard_colwise(rearanged_weight, process_group) + linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) + + if bias: + bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) + rearanged_bias_chunks = [bias_chunks[i] for i in new_order] + rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) + sharded_bias = shard_colwise(rearanged_bias, process_group) + linear_1d.bias.copy_(sharded_bias.contiguous()) + + return linear_1d + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + + def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. + # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) + input_parallel = input_ + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + if self.gather_output: + # All-gather across the partitions. + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + else: + output = output_parallel + + if self.skip_bias_add: + return output, self.bias + else: + return output + + +class LinearConv1D_Row(ParallelModule): + r""" Linear layer with row parallelism + + Args: + in_features (int): size of each input sample. + out_features (int): size of each output sample. + bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. + dtype (`torch.dtype`): The dtype of parameters, defaults to None. + parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False. + skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer, + which is preserved for kernel fusion, defaults to False + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + + More details about ``initializer`` please refer to + `init `_. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + parallel_input: bool = True, + skip_bias_add: bool = False, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): + super().__init__() + + self.stream_chunk_num = stream_chunk_num + + # Keep input parameters + self.in_features = in_features + self.out_features = out_features + self.parallel_input = parallel_input + self.skip_bias_add = skip_bias_add + self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) + + if skip_bias_add and not bias: + raise ValueError('cannot skip bias addition if bias is None') + + # Divide the weight matrix along the last dimension. + self.input_size_per_partition = divide(in_features, self.num_partitions) + + # Parameters. + # Initialize weight. + if device is None: + device = get_current_device() + + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() + if bias: + self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) + else: + self.bias = None + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer, bias_initializer) + + @staticmethod + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, + *args, **kwargs) -> ParallelModule: + r""" + Convert a native PyTorch linear layer to a parallelized linear layer. + """ + # get the attributes + in_features = module.weight.shape[0] + out_features = module.weight.shape[1] + bias = module.bias is not None + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + linear_1d = LinearConv1D_Row(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) + + # TODO: copy the sharded weights + with torch.no_grad(): + # the weigh to the linear layer is a transpose + # thus shard on col is equal to shard on row + + # first rearange the order of weight and bias + world_size = dist.get_world_size(group=process_group) + order = torch.arange(world_size * n_cast) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=0) + rearanged_weight_chunks = [weight_chunks[i] for i in new_order] + rearanged_weight = torch.cat(rearanged_weight_chunks, dim=0) + sharded_weight = shard_rowwise(rearanged_weight, process_group) + linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) + + if bias: + bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) + rearanged_bias_chunks = [bias_chunks[i] for i in new_order] + rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) + linear_1d.bias.copy_(rearanged_bias.contiguous()) + + return linear_1d + + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + + def reset_parameters(self, weight_initializer, bias_initializer) -> None: + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + self.bias = self.bias.cuda() + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias = self.bias.to(origin_device) + + def forward(self, input_: Tensor) -> Tensor: + # Set up backprop all-reduce. + if self.parallel_input: + assert input_.shape[-1] == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1]) + input_ = input_ + else: + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) + input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) + + if self.stream_chunk_num > 1: + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=self.process_group, + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) + else: + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, self.process_group) + + if not self.skip_bias_add: + if self.bias is not None: + output = output + self.bias + return output + else: + return output, self.bias diff --git a/colossalai/shardformer/layer/parallelmodule.py b/colossalai/shardformer/layer/parallelmodule.py new file mode 100644 index 000000000000..3d19bbea7e47 --- /dev/null +++ b/colossalai/shardformer/layer/parallelmodule.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import List, Union + +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import init as init + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class ParallelModule(nn.Module, ABC): + + @abstractmethod + def from_native_module(module: nn.Module, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "ParallelModule": + """ + Convert a native PyTorch module to a parallelized module. + + Args: + module (nn.Module): the module to be converted. + process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication. + If this is a list, the process group at the ith index of the list will correspond to the process group + in the ith axis of the device mesh. Defaults to None, which means the global process group. + """ + pass diff --git a/colossalai/shardformer/layer/vocabparallelembedding1d.py b/colossalai/shardformer/layer/vocabparallelembedding1d.py new file mode 100644 index 000000000000..4c325c68421b --- /dev/null +++ b/colossalai/shardformer/layer/vocabparallelembedding1d.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from collections import OrderedDict +from typing import Callable, List, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.nn.parameter import Parameter + +from colossalai.context import ParallelMode, seed +from colossalai.nn import init as init +from colossalai.nn.layer.base_layer import ParallelLayer +from colossalai.nn.layer.utils import divide +from colossalai.tensor.d_tensor.api import shard_rowwise +from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict + +from ._operation import reduce_input +from .parallelmodule import ParallelModule +from .utils import create_randomizer_with_offset + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class VocabParallelEmbedding1D(ParallelLayer): + r"""Embedding parallelized in the vocabulary dimension. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:``torch.nn.functional.embedding`` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about initializer please refer to + `init `_. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + self.num_embeddings = num_embeddings + self.embed_dim = embedding_dim + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + self.process_group = process_group + + tensor_parallel_size = dist.get_world_size(group=process_group) + tensor_parallel_rank = dist.get_rank(group=process_group) + + self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) + self.num_embeddings = self.num_embeddings_per_partition + self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition + self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + + self.weight = Parameter( + torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype)) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native pytorch embedding module to a parallel module. + """ + # get the origin attributes + num_embeddings = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + device = module.weight.device + + # ensure only one process group is used + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + # create the parallel module + vocab_embedding_1d = VocabParallelEmbedding1D(num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + device=device, + process_group=process_group, + *args, + **kwargs) + with torch.no_grad(): + # shard and slice the weight along the vocabulary(num_embeddings) dimension + # the shape of the weight is (num_embeddings, embedding_dim) + shard_weight = shard_rowwise(module.weight.data, process_group) + vocab_embedding_1d.weight.data.copy_(shard_weight) + + return vocab_embedding_1d + + def reset_parameters(self, weight_initializer) -> None: + with seed(ParallelMode.TENSOR): + fan_in, fan_out = self.num_embeddings, self.embed_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None and \ + self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index: + with torch.no_grad(): + self.weight[self.padding_idx - self.vocab_start_index].fill_(0) + + def _save_to_global_state_dict(self, destination, prefix, keep_vars): + weight_key = prefix + 'weight' + local_state = OrderedDict({weight_key: self.weight}) + local_state = gather_tensor_parallel_state_dict(local_state, + ParallelMode.PARALLEL_1D, + dims={weight_key: 0}, + partition_states={weight_key: True}, + keep_vars=keep_vars) + destination.update(local_state) + + def forward(self, input_: Tensor) -> Tensor: + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + + output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, + **self.embed_kwargs) + + # Mask the output embedding. + output_parallel[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(output_parallel, self.process_group) + return output diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 6ce0b8fb3a3d..5e7a285e3285 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -56,6 +56,8 @@ class PolicyLocation: PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), # GPT2 + "transformers.models.gpt2.modeling_gpt2.GPT2Model": + PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), } @@ -99,4 +101,3 @@ def get_autopolicy(model: nn.Module) -> Policy: else: policy = import_policy(policy_location) return policy() - return policy() diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 06ee9b435e7e..2a204f0defe4 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,7 +1,7 @@ import torch.nn as nn from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead -import colossalai.shardformer.layer.layers as col_nn +import colossalai.shardformer.layer as col_nn from colossalai.shardformer.layer.dropout import Dropout1D from ..utils import getattr_, setattr_ @@ -87,15 +87,9 @@ def module_policy(self): def new_model_class(self): # do nothing - return None + return self.model def postprocess(self): - binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} - for k, v in binding_map.items(): - param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) - setattr_(self.model, v, param) return self.model @@ -127,6 +121,15 @@ def module_policy(self): module_policy.update(addon_module) return module_policy + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + # BertForMaskedLM class BertForMaskedLMPolicy(BertPolicy): @@ -149,6 +152,15 @@ def module_policy(self): module_policy.update(addon_module) return module_policy + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + # BertLMHeadModel class BertLMHeadModelPolicy(BertPolicy): @@ -171,6 +183,15 @@ def module_policy(self): module_policy.update(addon_module) return module_policy + def postprocess(self): + binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + # BertForNextSentencePrediction class BertForNextSentencePredictionPolicy(BertPolicy): diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 0d4342e75783..d255325b2084 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,126 +1,101 @@ -from typing import Any, Callable, Dict, List, Tuple, Type +from typing import Type, Union import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model -import colossalai.shardformer.layer.layers as col_nn +import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.layer.dropout import Dropout1D -from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer +from ..utils import getattr_, setattr_ +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class GPT2Policy(Policy): - @staticmethod - def argument_policy(config, world_size): + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): return { GPT2Model: - Argument(attr_dict={}, param_funcs=[ - GPT2Policy.embedding, - ]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ]), GPT2Block: - Argument( - attr_dict={ - # 1. reduce hidden size - "attn.embed_dim": config.hidden_size // world_size, - "attn.split_size": config.hidden_size // world_size, - "crossattention.embed_dim": config.hidden_size // world_size, - "crossattention.split_size": config.hidden_size // world_size, - # 2. reduce number of heads - "attn.num_heads": config.num_attention_heads // world_size, - "crossattention.num_heads": config.num_attention_heads // world_size, - }, - param_funcs=[ - GPT2Policy.attn_in, - GPT2Policy.attn_out, - GPT2Policy.mlp_in, - GPT2Policy.mlp_out, - ]), + ModulePolicyDescription(attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.LinearConv1D_Col, + kwargs={ + "n_cast": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.LinearConv1D_Row, + kwargs={ + "n_cast": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.LinearConv1D_Col, + kwargs={ + "n_cast": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.LinearConv1D_Row, + kwargs={ + "n_cast": 1, + }, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.Dropout1D, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.Dropout1D, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.Dropout1D, + ), + ]) } - @staticmethod - def attn_in() -> List: - return [ - Col_Layer(suffix="attn.c_attn", - weight="weight", - bias="bias", - n_cast=3, - reversed=True, - replace_layer=col_nn.Linear1D_Col), - Col_Layer(suffix="crossattention.c_attn", - weight="weight", - bias="bias", - n_cast=2, - reversed=True, - ignore=True, - replace_layer=col_nn.Linear1D_Col), - Col_Layer(suffix="crossattention.q_attn", - weight="weight", - bias="bias", - reversed=True, - ignore=True, - replace_layer=col_nn.Linear1D_Col) - ] + def new_model_class(self): - @staticmethod - def attn_out() -> List: - return [ - Row_Layer(suffix="attn.c_proj", - weight="weight", - bias="bias", - reversed=True, - replace_layer=col_nn.Linear1D_Row), - Row_Layer(suffix="crossattention.c_proj", - weight="weight", - bias="bias", - reversed=True, - ignore=True, - replace_layer=col_nn.Linear1D_Row) - ] + return self.model - @staticmethod - def mlp_in() -> List: - return [ - Col_Layer(suffix="mlp.c_fc", weight="weight", bias="bias", reversed=True, - replace_layer=col_nn.Linear1D_Col), - ] + def postprocess(self): + return self.model - @staticmethod - def mlp_out() -> List: - return [ - Row_Layer(suffix="mlp.c_proj", - weight="weight", - bias="bias", - reversed=True, - replace_layer=col_nn.Linear1D_Row) - ] - @staticmethod - def embedding() -> List: - return [Col_Layer(suffix="wte", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)] +# GPT2Model +class GPT2ModelPolicy(GPT2Policy): - -from transformers import GPT2LMHeadModel - - -class GPT2LMHeadModelPolicy(GPT2Policy): - - @staticmethod - def argument_policy(config, world_size): - base_argument = GPT2Policy.argument_policy(config, world_size) - argument = { - GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[ - GPT2LMHeadModelPolicy.unembedding, - ]), - } - argument.update(base_argument) - return argument - - @staticmethod - def unembedding() -> List: - return [ - Col_Layer(suffix="lm_head", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True) - ] + def __init__(self) -> None: + super().__init__() diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 54fea0335e54..043ed1a74a27 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -108,7 +108,7 @@ def check_bert(rank, world_size, port): backward_lsit = [BertForMaskedLM, BertLMHeadModel] for model_fn in forward_list: - org_model, sharded_model = build_model(model_fn) + org_model, sharded_model = build_model(world_size, model_fn) check_forward(org_model, sharded_model) if model_fn in backward_lsit: diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py new file mode 100644 index 000000000000..2f679b83f99b --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -0,0 +1,118 @@ +import copy +import os + +import pytest +import torch +from transformers import AutoTokenizer, GPT2Config, GPT2Model + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + + +def build_model(world_size, model_fn): + config = GPT2Config() + config.attn_pdrop = 0 + config.embd_pdrop = 0 + config.resid_pdrop = 0 + config.summary_first_dropout + + org_model = model_fn(config=config) + org_model_forshard = copy.deepcopy(org_model) + + org_model.to('cuda') + # TODO: no need to transfer to cuda + org_model_forshard.to('cuda') + shard_config = ShardConfig(tensor_parallel_size=world_size,) + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + sharded_model = shard_former.shard_model(org_model_forshard).to('cuda') + + return org_model, sharded_model + + +def check_forward(org_model, sharded_model): + input = 'Hello, my dog is cute' + tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + + #orgin model + org_model.eval() + org_out = org_model(**tokenized_input) + + #shard model + sharded_model.eval() + shard_out = sharded_model(**tokenized_input) + + assert torch.allclose( + org_out[0], shard_out[0], + atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" + + +def check_backward(org_model, sharded_model): + # prepare input + input = 'Hello, my dog is cute' + tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + labels = tokenized_input['input_ids'].clone() + labels[labels == tokenizer.pad_token_id] = -100 + # tokenized_input['labels'] = labels + + #orgin model + org_model.train() + org_out = org_model(**tokenized_input) + org_loss = org_out.loss + org_loss.backward() + org_grad = org_model.h[0].attn.c_attn.weight.grad + + #shard model + sharded_model.train() + shard_out = sharded_model(**tokenized_input) + shard_loss = shard_out.loss + shard_loss.backward() + shard_grad = sharded_model.h[0].attn.c_attn.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + forward_list = [ + GPT2Model, + + # TODO: do not work yet + # BertModel, + # BertForSequenceClassification + # BertForNextSentencePrediction, + ] + backward_lsit = [] + + for model_fn in forward_list: + org_model, sharded_model = build_model(world_size, model_fn) + check_forward(org_model, sharded_model) + + if model_fn in backward_lsit: + check_backward(org_model, sharded_model) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_gpt2(): + spawn(check_bert, 2) + + +if __name__ == "__main__": + test_gpt2() From ac3aef3a85d78b085df6b65f6482060df7b85b12 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 21 Jun 2023 09:32:46 +0800 Subject: [PATCH 29/49] [shardformer] adapted T5 and LLaMa test to use kit (#4049) * [shardformer] adapted T5 and LLaMa test to use kit * polish code --- colossalai/shardformer/layer/embedding1d.py | 22 ++++-- colossalai/shardformer/policies/llama.py | 2 +- colossalai/shardformer/policies/t5.py | 3 +- colossalai/shardformer/shard/sharder.py | 11 ++- colossalai/testing/comparison.py | 2 +- tests/kit/model_zoo/registry.py | 30 ++++--- tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/llama.py | 76 ++++++++++++++++++ tests/kit/model_zoo/transformers/t5.py | 53 ++++++++++--- .../test_mixed_precision/test_fp16_torch.py | 2 +- .../test_plugin/test_gemini_plugin.py | 2 +- .../test_plugin/test_low_level_zero_plugin.py | 2 +- .../test_plugin/test_torch_ddp_plugin.py | 2 +- .../test_plugin/test_torch_fsdp_plugin.py | 2 +- .../test_hf_model/test_hf_diffuser.py | 4 +- .../test_timm_model/test_timm_model.py | 2 +- .../test_torchaudio_model.py | 2 +- tests/test_lazy/lazy_init_utils.py | 2 +- tests/test_lazy/test_distribute.py | 2 +- tests/test_shardformer/__init__.py | 0 tests/test_shardformer/test_model/__init__.py | 0 tests/test_shardformer/test_model/_utils.py | 38 +++++++++ .../test_model/test_shard_llama.py | 78 ++++--------------- .../test_model/test_shard_t5.py | 75 ++++-------------- 24 files changed, 242 insertions(+), 171 deletions(-) create mode 100644 tests/kit/model_zoo/transformers/llama.py create mode 100644 tests/test_shardformer/__init__.py create mode 100644 tests/test_shardformer/test_model/__init__.py create mode 100644 tests/test_shardformer/test_model/_utils.py diff --git a/colossalai/shardformer/layer/embedding1d.py b/colossalai/shardformer/layer/embedding1d.py index 1108d5d6a936..ace7deb3ad0c 100644 --- a/colossalai/shardformer/layer/embedding1d.py +++ b/colossalai/shardformer/layer/embedding1d.py @@ -65,13 +65,14 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + gather_output: bool = True, weight_initializer: Callable = init.normal_(), *args, **kwargs): super().__init__() self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim + self.embedding_dim = embedding_dim self.process_group = process_group self.num_partitions = dist.get_world_size(process_group) self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) @@ -79,7 +80,7 @@ def __init__(self, self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs - # self.gather_output = gather_output + self.gather_output = gather_output if device is None: device = get_current_device() @@ -95,7 +96,9 @@ def __init__(self, @staticmethod def from_native_module(module: nn.Embedding, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D": + process_group: Union[ProcessGroup, List[ProcessGroup]] = None, + *args, + **kwargs) -> "Embedding1D": r""" Build a 1D parallelized Embedding from a native nn.Embedding module. """ @@ -123,7 +126,9 @@ def from_native_module(module: nn.Embedding, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) + sparse=sparse, + *args, + **kwargs) # copy the weight with torch.no_grad(): @@ -133,7 +138,7 @@ def from_native_module(module: nn.Embedding, return embedding def reset_parameters(self, weight_initializer) -> None: - fan_in, fan_out = self.num_embeddings, self.embed_dim + fan_in, fan_out = self.num_embeddings, self.embedding_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() @@ -144,6 +149,9 @@ def _fill_padding_idx_with_zero(self) -> None: def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - return output + if self.gather_output: + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + return output + else: + return output_parallel diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index ae1b794fca12..a13f5f087da4 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -4,7 +4,7 @@ from transformers import LlamaForCausalLM, LlamaForSequenceClassification from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel -from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 9c8ee59b4178..9e0c8604969c 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -11,8 +11,7 @@ T5Stack, ) -from colossalai.shardformer.layer.dropout import Dropout1D -from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index f6ade26b758a..66934b09b3ac 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -185,7 +185,14 @@ def _replace_sub_module( if description.ignore_if_not_exist and native_sub_module is None: continue - replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'], - **kwargs) + try: + replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'], + **kwargs) + except Exception as e: + raise RuntimeError( + f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}" + f" with {target_module.__qualname__} with the exception: {e}. " + "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well." + ) setattr_(org_layer, suffix, replace_layer) diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index aeecee7f11f5..5cbfb936b144 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -98,6 +98,6 @@ def assert_hf_output_close(out1: Any, raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}") assert torch.allclose( out1, out2, atol=atol, rtol=rtol - ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}" + ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}" else: assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}" diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 6cc4c8ef370d..efbf3a4d37b1 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -28,27 +28,35 @@ def register(self, model_fn: Callable, data_gen_fn: Callable, output_transform_fn: Callable, + loss_fn: Callable = None, model_attribute: ModelAttribute = None): """ Register a model and data generation function. Examples: - >>> # Register - >>> model_zoo = ModelZooRegistry() - >>> model_zoo.register('resnet18', resnet18, resnet18_data_gen) - >>> # Run the model - >>> data = resnet18_data_gen() # do not input any argument - >>> model = resnet18() # do not input any argument - >>> out = model(**data) + + ```python + # normal forward workflow + model = resnet18() + data = resnet18_data_gen() + output = model(**data) + transformed_output = output_transform_fn(output) + loss = loss_fn(transformed_output) + + # Register + model_zoo = ModelZooRegistry() + model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn) + ``` Args: name (str): Name of the model. - model_fn (callable): A function that returns a model. **It must not contain any arguments.** - output_transform_fn (callable): A function that transforms the output of the model into Dict. - data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** + model_fn (Callable): A function that returns a model. **It must not contain any arguments.** + data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** + output_transform_fn (Callable): A function that transforms the output of the model into Dict. + loss_fn (Callable): a function to compute the loss from the given output. Defaults to None model_attribute (ModelAttribute): Attributes of the model. Defaults to None. """ - self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute) + self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute) def get_sub_registry(self, keyword: str): """ diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index f56ff7ad84eb..ffaf4c566df9 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -1,5 +1,6 @@ from .albert import * from .bert import * from .gpt import * +from .llama import * from .opt import * from .t5 import * diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py new file mode 100644 index 000000000000..705bbc7364ba --- /dev/null +++ b/tests/kit/model_zoo/transformers/llama.py @@ -0,0 +1,76 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +try: + from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel + HAS_LLAMA = True +except ImportError: + HAS_LLAMA = False + +if HAS_LLAMA: + # =============================== + # Register LLaMA + # =============================== + + def data_gen(): + # the input ids are corresponding to the sentence + # 'Hello, my dog is cute' + # + # the code is give below: + # ----------------------------------- + # from transformers import LlamaTokenizerFast + # tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + # ----------------------------------- + + input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() + return dict(input_ids=input_ids, attention_mask=attention_mask) + + # label is needed for casual lm + def data_gen_for_casual_lm(): + data = data_gen() + labels = data['input_ids'].clone() + data['labels'] = labels + return data + + # transform the output to a dict + output_transform_fn = lambda x: x + + # function to get the loss + loss_fn = lambda output: output.last_hidden_state.mean() + loss_fn_for_casual_lm = lambda output: output.loss + loss_fn_for_seq_classification = lambda output: output.logits.mean() + + config = LlamaConfig(num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128, + num_labels=16) + + # register the following models + # transformers.LlamaModel, + # transformers.LlamaForCausalLM, + # transformers.LlamaForSequenceClassification, + model_zoo.register(name='transformers_llama', + model_fn=lambda: transformers.LlamaModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn, + model_attribute=ModelAttribute(has_control_flow=True)) + model_zoo.register(name='transformers_llama_for_casual_lm', + model_fn=lambda: transformers.LlamaForCausalLM(config), + data_gen_fn=data_gen_for_casual_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_casual_lm, + model_attribute=ModelAttribute(has_control_flow=True)) + model_zoo.register(name='transformers_llama_for_sequence_classification', + model_fn=lambda: transformers.LlamaForSequenceClassification(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_seq_classification, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index b81bcad90db8..689db2c40abb 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -6,24 +6,50 @@ # =============================== # Register single-sentence T5 # =============================== -BATCH_SIZE = 2 -SEQ_LENGTH = 16 - - -def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) +# define data gen function def data_gen_for_encoder_only(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + # Generated from following code snippet + # + # from transformers import T5Config, T5Tokenizer + # config = T5Config(decoder_start_token_id=0) + # tokenizer = T5Tokenizer.from_pretrained("t5-small") + # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids + input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long() return dict(input_ids=input_ids) +def data_gen_for_conditional_generation(): + # labels is generated with the following code + # + # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids + data = data_gen_for_encoder_only() + labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long() + data['labels'] = labels + return data + + +def data_gen_for_t5_model(): + # decoder_inputs_ids is obtained with the following code + # + # decoder_input_ids = model._shift_right(input_ids) + data = data_gen_for_encoder_only() + decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long() + data['decoder_input_ids'] = decoder_input_ids + return data + + +# output transform function output_transform_fn = lambda x: x -config = transformers.T5Config(d_model=128, num_layers=2) +# define loss funciton +loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean() +loss_fn_for_conditional_generation = lambda x: x.loss + +# define model config +config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0) # register the following models # transformers.T5Model, @@ -31,16 +57,19 @@ def data_gen_for_encoder_only(): # transformers.T5EncoderModel, model_zoo.register(name='transformers_t5', model_fn=lambda: transformers.T5Model(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_t5_model, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_t5_model, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_t5_for_conditional_generation', model_fn=lambda: transformers.T5ForConditionalGeneration(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_conditional_generation, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_conditional_generation, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_t5_encoder_model', model_fn=lambda: transformers.T5EncoderModel(config), data_gen_fn=data_gen_for_encoder_only, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_encoder_only, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py index 963387da262b..26ce00e94869 100644 --- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -11,7 +11,7 @@ def run_torch_amp(rank, world_size, port): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') sub_model_zoo = model_zoo.get_sub_registry('timm') - for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items(): # dlrm_interactionarch has not parameters, so skip if name == 'dlrm_interactionarch': continue diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index d606d6d89bd4..d29c92926066 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -71,7 +71,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): passed_models = [] failed_info = {} # (model_name, error) pair - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): # These models lead to CUDA error if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp', 'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'): diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index f70f27be2aa7..eedd8c59a3a8 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -61,7 +61,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS skipped_models = [] - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): # FIXME(ver217): fix these models if name in ignore_models: skipped_models.append(name) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index fbe44e5ce6fb..1484273973ae 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -40,7 +40,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_ddp_plugin(): - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): if name == 'dlrm_interactionarch': continue run_fn(model_fn, data_gen_fn, output_transform_fn) diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index 44767f051fdd..cbd5d57800db 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -42,7 +42,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_fsdp_plugin(): - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): if any(element in name for element in [ 'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet', 'torchvision_inception_v3' diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index 0cbea82e083a..ccbe2da58bf2 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -47,7 +47,7 @@ def test_diffusers(): sub_model_zoo = model_zoo.get_sub_registry('diffusers') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() trace_and_compare(model_fn, data, output_transform_fn) torch.cuda.synchronize() @@ -60,7 +60,7 @@ def test_torch_diffusers(): sub_model_zoo = model_zoo.get_sub_registry('diffusers') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() model = model_fn() output = model(**data) diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 11302e8f36b0..117c70c84aa8 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -56,7 +56,7 @@ def test_timm_models(): sub_model_zoo = model_zoo.get_sub_registry('timm') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: meta_args = {k: v.to('meta') for k, v in data.items()} diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py index eafcaca10b1d..f73c5bb9a590 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -16,7 +16,7 @@ def test_torchaudio_models(): sub_model_zoo = model_zoo.get_sub_registry('torchaudio') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): model = model_fn() trace_and_compare(model, data_gen_fn, diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 85bfd0e27801..2dd8d1ca3216 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -60,7 +60,7 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None: - model_fn, data_gen_fn, output_transform_fn, model_attr = entry + model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry _MyTensor._pre_op_fn = lambda *args: set_seed(seed) LazyTensor._pre_op_fn = lambda *args: set_seed(seed) ctx = LazyInitContext(tensor_cls=_MyTensor) diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py index d515b175a9ea..f33c037e3de6 100644 --- a/tests/test_lazy/test_distribute.py +++ b/tests/test_lazy/test_distribute.py @@ -78,7 +78,7 @@ def run_dist_lazy_init(subset, seed: int = 42): if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): continue print_rank_0(name) - model_fn, data_gen_fn, output_transform_fn, model_attr = entry + model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry ctx = LazyInitContext(tensor_cls=_MyTensor) with ctx: model = model_fn() diff --git a/tests/test_shardformer/__init__.py b/tests/test_shardformer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_shardformer/test_model/__init__.py b/tests/test_shardformer/test_model/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py new file mode 100644 index 000000000000..52ca7fce895b --- /dev/null +++ b/tests/test_shardformer/test_model/_utils.py @@ -0,0 +1,38 @@ +import copy + +from colossalai.shardformer import ShardConfig, ShardFormer + + +def build_model(world_size, model_fn): + # create new model + org_model = model_fn().cuda() + + # shard model + shard_config = ShardConfig(tensor_parallel_size=world_size) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + sharded_model = shard_former.shard_model(model_copy) + + return org_model, sharded_model + + +def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # prepare input + data = data_gen_fn() + data = {k: v.cuda() for k, v in data.items()} + + # switch to train mode + original_model.train() + sharded_model.train() + + # run forward + org_output = original_model(**data) + org_output = output_transform_fn(org_output) + org_loss = loss_fn(org_output) + + shard_output = sharded_model(**data) + shard_output = output_transform_fn(shard_output) + shard_loss = loss_fn(shard_output) + + return org_output, org_loss, shard_output, shard_loss diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index b15f81aba52e..8b672af500bd 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -1,64 +1,22 @@ -import copy import os -import random import pytest import torch -from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") - - -def build_model(world_size, model_fn): - # create new model - config = LlamaConfig(num_hidden_layers=4, - hidden_size=128, - intermediate_size=256, - num_attention_heads=4, - max_position_embeddings=128) - org_model = model_fn(config).cuda() - - # shard model - shard_config = ShardConfig(tensor_parallel_size=world_size) - model_copy = copy.deepcopy(org_model) - shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() - sharded_model = shard_former.shard_model(model_copy) - - return org_model, sharded_model - - -def check_forward_backward(org_model, sharded_model): - # prepare input - input = 'Hello, my dog is cute' - tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') - del tokenized_input["token_type_ids"] - del tokenized_input["attention_mask"] - - # switch to train mode - org_model.train() - sharded_model.train() - - if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)): - org_output = org_model(**tokenized_input) - org_loss = org_output.last_hidden_state.mean() - shard_output = sharded_model(**tokenized_input) - shard_loss = shard_output.last_hidden_state.mean() - elif isinstance(org_model, LlamaForCausalLM): - labels = tokenized_input['input_ids'].clone() - labels[labels == tokenizer.pad_token_id] = -100 - tokenized_input['labels'] = labels - org_output = org_model(**tokenized_input) - org_loss = org_output.loss - shard_output = sharded_model(**tokenized_input) - shard_loss = shard_output.loss + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + + # forward check assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4) # run backward @@ -66,12 +24,12 @@ def check_forward_backward(org_model, sharded_model): shard_loss.backward() # check grad - if isinstance(org_model, LlamaModel): - llama_model = org_model - shard_llama_model = sharded_model - else: + if hasattr(org_model, 'model'): llama_model = org_model.model shard_llama_model = sharded_model.model + else: + llama_model = org_model + shard_llama_model = sharded_model org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad @@ -89,17 +47,11 @@ def check_llama(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - model_list = [ - LlamaModel, - # LlamaForCausalLM, - - # TODO: do not work yet - # LlamaForSequenceClassification - ] + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') - for model_fn in model_list: + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(world_size, model_fn) - check_forward_backward(org_model, sharded_model) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 254649409c59..2698d7675c8e 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,64 +1,20 @@ -import copy import os import pytest import torch -from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.shard import ShardConfig, ShardFormer from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) -tokenizer = T5Tokenizer.from_pretrained("t5-small") - - -def build_model(world_size, model_fn): - config = T5Config(decoder_start_token_id=0) - config.dropout_rate = 0 - org_model = model_fn(config=config).to('cuda') - shard_config = ShardConfig(tensor_parallel_size=world_size) - - # shard model - shard_config = ShardConfig(tensor_parallel_size=world_size) - model_copy = copy.deepcopy(org_model) - shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() - sharded_model = shard_former.shard_model(model_copy) - - return org_model, sharded_model - - -def check_forward_backward(org_model, sharded_model): - # prepare input - input_ids = tokenizer("translate English to German: The house is wonderful.", - return_tensors="pt").input_ids.to('cuda') - labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda') - - # switch to train mode - org_model.train() - sharded_model.train() - - if isinstance(org_model, T5ForConditionalGeneration): - org_output = org_model(input_ids=input_ids, labels=labels) - org_loss = org_output.loss - shard_output = sharded_model(input_ids=input_ids, labels=labels) - shard_loss = shard_output.loss - elif isinstance(org_model, T5Model): - decoder_input_ids = org_model._shift_right(input_ids) - org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - org_loss = org_output.last_hidden_state.mean() - shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - shard_loss = shard_output.last_hidden_state.mean() - elif isinstance(org_model, T5EncoderModel): - org_output = org_model(input_ids=input_ids) - org_loss = org_output.last_hidden_state.mean() - shard_output = sharded_model(input_ids=input_ids) - shard_loss = shard_output.last_hidden_state.mean() - - # key is sharded, so we ignore + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + # the value "past_key_values" is sharded, so we ignore + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) # do backward @@ -81,18 +37,15 @@ def check_forward_backward(org_model, sharded_model): def check_t5(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - model_fn_list = [ - T5Model, - T5ForConditionalGeneration, - T5EncoderModel, - ] + sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') - for model_fn in model_fn_list: + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(world_size, model_fn) - check_forward_backward(org_model, sharded_model) - torch.cuda.empty_cache() + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() @pytest.mark.dist From e5d4a87be79c8c8d495d48fe8c90cd831256934f Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 21 Jun 2023 14:30:06 +0800 Subject: [PATCH 30/49] [shardformer] refactored the shardformer layer structure (#4053) --- .../shardformer/{utils/utils.py => _utils.py} | 3 + colossalai/shardformer/layer/__init__.py | 19 +- colossalai/shardformer/layer/_operation.py | 2 - colossalai/shardformer/layer/dropout.py | 4 +- ...cabparallelembedding1d.py => embedding.py} | 165 +++++++++++++++--- colossalai/shardformer/layer/embedding1d.py | 157 ----------------- colossalai/shardformer/layer/layernorm1d.py | 73 -------- .../layer/{linear1d.py => linear.py} | 22 +-- .../layer/{linearconv1d.py => linear_conv.py} | 57 +++--- .../layer/{dist_crossentropy.py => loss.py} | 4 +- .../{parallelmodule.py => parallel_module.py} | 10 +- colossalai/shardformer/policies/basepolicy.py | 2 +- colossalai/shardformer/policies/bert.py | 2 +- colossalai/shardformer/policies/gpt2.py | 5 - colossalai/shardformer/policies/t5.py | 2 - colossalai/shardformer/shard/sharder.py | 2 +- colossalai/shardformer/utils/__init__.py | 1 - .../test_dist_crossentropy.py} | 4 +- .../test_layer/test_dropout.py | 2 +- .../test_layer/test_embedding.py | 2 +- .../test_layer/test_linear_1d.py | 2 +- .../test_vocab_parallel_embedding_1d.py | 2 +- .../test_module/test_dropout.py | 51 ------ .../test_module/test_slicer.py | 78 --------- 24 files changed, 198 insertions(+), 473 deletions(-) rename colossalai/shardformer/{utils/utils.py => _utils.py} (97%) rename colossalai/shardformer/layer/{vocabparallelembedding1d.py => embedding.py} (52%) delete mode 100644 colossalai/shardformer/layer/embedding1d.py delete mode 100644 colossalai/shardformer/layer/layernorm1d.py rename colossalai/shardformer/layer/{linear1d.py => linear.py} (96%) rename colossalai/shardformer/layer/{linearconv1d.py => linear_conv.py} (92%) rename colossalai/shardformer/layer/{dist_crossentropy.py => loss.py} (98%) rename colossalai/shardformer/layer/{parallelmodule.py => parallel_module.py} (78%) delete mode 100644 colossalai/shardformer/utils/__init__.py rename tests/test_shardformer/{test_module/test_distcrossentropy.py => test_layer/test_dist_crossentropy.py} (87%) delete mode 100644 tests/test_shardformer/test_module/test_dropout.py delete mode 100644 tests/test_shardformer/test_module/test_slicer.py diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/_utils.py similarity index 97% rename from colossalai/shardformer/utils/utils.py rename to colossalai/shardformer/_utils.py index 05a6a3ae6c30..a1c7203a929f 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/_utils.py @@ -2,6 +2,9 @@ def get_obj_list_element(obj, a): + r""" + Get the element of the list in the object + """ re_pattern = r'\[\d+\]' prog = re.compile(re_pattern) result = prog.search(a) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 66d86913bb2b..808ebbc12aeb 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,17 +1,10 @@ from .dropout import Dropout1D -from .embedding1d import Embedding1D -from .layernorm1d import LayerNorm1D -from .linear1d import Linear1D_Col, Linear1D_Row -from .linearconv1d import LinearConv1D_Col, LinearConv1D_Row -from .vocabparallelembedding1d import VocabParallelEmbedding1D +from .embedding import Embedding1D, VocabParallelEmbedding1D +from .linear import Linear1D_Col, Linear1D_Row +from .linear_conv import LinearConv1D_Col, LinearConv1D_Row +from .loss import cross_entropy_1d __all__ = [ - "Embedding1D", - "VocabParallelEmbedding1D", - "Linear1D_Col", - "Linear1D_Row", - "LinearConv1D_Col", - "LinearConv1D_Row", - "LayerNorm1D", - "Dropout1D", + "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row", + "Dropout1D", "cross_entropy_1d" ] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 208a391c33e2..280d5526342b 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,8 +1,6 @@ import torch import torch.distributed as dist -from colossalai.core import global_context as gpc - try: import fused_mix_prec_layer_norm_cuda except: diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index 08dfb8afd7fb..2c49b49faad6 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -4,9 +4,11 @@ import torch.nn as nn from torch.distributed import ProcessGroup -from .parallelmodule import ParallelModule +from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset +__all__ = ['Dropout1D'] + class Dropout1D(ParallelModule, nn.Dropout): """ diff --git a/colossalai/shardformer/layer/vocabparallelembedding1d.py b/colossalai/shardformer/layer/embedding.py similarity index 52% rename from colossalai/shardformer/layer/vocabparallelembedding1d.py rename to colossalai/shardformer/layer/embedding.py index 4c325c68421b..8b9fb03ec798 100644 --- a/colossalai/shardformer/layer/vocabparallelembedding1d.py +++ b/colossalai/shardformer/layer/embedding.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from collections import OrderedDict from typing import Callable, List, Union import torch @@ -12,26 +11,148 @@ from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter -from colossalai.context import ParallelMode, seed from colossalai.nn import init as init -from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_rowwise -from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise +from colossalai.utils.cuda import get_current_device -from ._operation import reduce_input -from .parallelmodule import ParallelModule +from ._operation import gather_forward_split_backward, reduce_input +from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass +__all__ = ['Embedding1D', 'VocabParallelEmbedding1D'] -class VocabParallelEmbedding1D(ParallelLayer): +class Embedding1D(ParallelModule): + r"""Embedding for 1D parallelism. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. + padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. + + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: torch.dtype = None, + device: torch.device = None, + process_group: ProcessGroup = None, + gather_output: bool = True, + weight_initializer: Callable = init.normal_(), + *args, + **kwargs): + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.process_group = process_group + self.num_partitions = dist.get_world_size(process_group) + self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) + + self.padding_idx = padding_idx + self.embed_args = args + self.embed_kwargs = kwargs + self.gather_output = gather_output + + if device is None: + device = get_current_device() + + self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) + + # offset the seed with randomizer index and rank + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) + + with self.randomizer.fork_rng(enable_cpu=True): + self.reset_parameters(weight_initializer) + + @staticmethod + def from_native_module(module: nn.Embedding, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None, + *args, + **kwargs) -> "Embedding1D": + r""" + Build a 1D parallelized Embedding from a native nn.Embedding module. + """ + # get the attributes + num_embedding = module.num_embeddings + embedding_dim = module.embedding_dim + padding_idx = module.padding_idx + max_norm = module.max_norm + norm_type = module.norm_type + scale_grad_by_freq = module.scale_grad_by_freq + sparse = module.sparse + dtype = module.weight.dtype + device = module.weight.device + + # sparse is not support yet + if sparse: + raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") + + embedding = Embedding1D(num_embeddings=num_embedding, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + process_group=process_group, + dtype=dtype, + device=device, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + *args, + **kwargs) + + # copy the weight + with torch.no_grad(): + sharded_weight = shard_colwise(module.weight.data, process_group) + embedding.weight.copy_(sharded_weight) + + return embedding + + def reset_parameters(self, weight_initializer) -> None: + fan_in, fan_out = self.num_embeddings, self.embedding_dim + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + self._fill_padding_idx_with_zero() + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input_: Tensor) -> Tensor: + output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) + + if self.gather_output: + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + return output + else: + return output_parallel + + +class VocabParallelEmbedding1D(ParallelModule): r"""Embedding parallelized in the vocabulary dimension. Args: @@ -93,9 +214,7 @@ def __init__(self, # offset the seed with randomizer index and rank seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) + self.reset_parameters(weight_initializer) @staticmethod def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -132,7 +251,7 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, return vocab_embedding_1d def reset_parameters(self, weight_initializer) -> None: - with seed(ParallelMode.TENSOR): + with self.randomizer.fork_rng(enable_cpu=True): fan_in, fan_out = self.num_embeddings, self.embed_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() @@ -143,16 +262,6 @@ def _fill_padding_idx_with_zero(self) -> None: with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - def _save_to_global_state_dict(self, destination, prefix, keep_vars): - weight_key = prefix + 'weight' - local_state = OrderedDict({weight_key: self.weight}) - local_state = gather_tensor_parallel_state_dict(local_state, - ParallelMode.PARALLEL_1D, - dims={weight_key: 0}, - partition_states={weight_key: True}, - keep_vars=keep_vars) - destination.update(local_state) - def forward(self, input_: Tensor) -> Tensor: # Build the mask. input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) diff --git a/colossalai/shardformer/layer/embedding1d.py b/colossalai/shardformer/layer/embedding1d.py deleted file mode 100644 index ace7deb3ad0c..000000000000 --- a/colossalai/shardformer/layer/embedding1d.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Callable, List, Union - -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.nn.parameter import Parameter - -from colossalai.nn import init as init -from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise -from colossalai.utils.cuda import get_current_device - -from ._operation import gather_forward_split_backward -from .parallelmodule import ParallelModule -from .utils import create_randomizer_with_offset - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class Embedding1D(ParallelModule): - r"""Embedding for 1D parallelism. - - Args: - num_embeddings (int): number of embeddings. - embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. - - More details about ``initializer`` please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: torch.dtype = None, - device: torch.device = None, - process_group: ProcessGroup = None, - gather_output: bool = True, - weight_initializer: Callable = init.normal_(), - *args, - **kwargs): - super().__init__() - - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - self.process_group = process_group - self.num_partitions = dist.get_world_size(process_group) - self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) - - self.padding_idx = padding_idx - self.embed_args = args - self.embed_kwargs = kwargs - self.gather_output = gather_output - - if device is None: - device = get_current_device() - - self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) - - # offset the seed with randomizer index and rank - seed = torch.random.initial_seed() - self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer) - - @staticmethod - def from_native_module(module: nn.Embedding, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None, - *args, - **kwargs) -> "Embedding1D": - r""" - Build a 1D parallelized Embedding from a native nn.Embedding module. - """ - # get the attributes - num_embedding = module.num_embeddings - embedding_dim = module.embedding_dim - padding_idx = module.padding_idx - max_norm = module.max_norm - norm_type = module.norm_type - scale_grad_by_freq = module.scale_grad_by_freq - sparse = module.sparse - dtype = module.weight.dtype - device = module.weight.device - - # sparse is not support yet - if sparse: - raise NotImplementedError("The Embedding1D module does not support sparse embedding yet.") - - embedding = Embedding1D(num_embeddings=num_embedding, - embedding_dim=embedding_dim, - padding_idx=padding_idx, - process_group=process_group, - dtype=dtype, - device=device, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse, - *args, - **kwargs) - - # copy the weight - with torch.no_grad(): - sharded_weight = shard_colwise(module.weight.data, process_group) - embedding.weight.copy_(sharded_weight) - - return embedding - - def reset_parameters(self, weight_initializer) -> None: - fan_in, fan_out = self.num_embeddings, self.embedding_dim - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - self._fill_padding_idx_with_zero() - - def _fill_padding_idx_with_zero(self) -> None: - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input_: Tensor) -> Tensor: - output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - if self.gather_output: - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - return output - else: - return output_parallel diff --git a/colossalai/shardformer/layer/layernorm1d.py b/colossalai/shardformer/layer/layernorm1d.py deleted file mode 100644 index 78bd64cfb504..000000000000 --- a/colossalai/shardformer/layer/layernorm1d.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from collections import OrderedDict - -from colossalai.context import ParallelMode, seed -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.kernel import LayerNorm -from colossalai.nn import init as init -from colossalai.nn.layer.colossalai_layer._utils import ColossalaiModule -from colossalai.utils.checkpointing import broadcast_state_dict - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - - -class LayerNorm1D(ColossalaiModule): - r""" - Layer Normalization for colossalai - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - _fast_ln_supported_sizes = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536 - ] - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None): - if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: - norm = Fast_LN(normalized_shape, eps=eps).to(dtype) - else: - norm = None - try: - from apex.normalization import FusedLayerNorm - norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) - except ImportError: - norm = LayerNorm(normalized_shape, eps=eps).to(dtype) - super().__init__(norm) - - def _load_from_state_dict(self, state_dict, prefix, *args): - local_state = OrderedDict() - weight_key = prefix + 'weight' - bias_key = prefix + 'bias' - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - # weight - weight = state_dict.pop(weight_key, None) - if weight is not None: - local_state[weight_key] = weight - # bias - bias = state_dict.pop(bias_key, None) - if bias is not None: - local_state[bias_key] = bias - - local_state = broadcast_state_dict(local_state, ParallelMode.PARALLEL_1D) - super()._load_from_state_dict(local_state, prefix, *args) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - if gpc.get_local_rank(ParallelMode.TENSOR) == 0: - super()._save_to_state_dict(destination, prefix, keep_vars) diff --git a/colossalai/shardformer/layer/linear1d.py b/colossalai/shardformer/layer/linear.py similarity index 96% rename from colossalai/shardformer/layer/linear1d.py rename to colossalai/shardformer/layer/linear.py index d59d32df824e..b87981c6db42 100644 --- a/colossalai/shardformer/layer/linear1d.py +++ b/colossalai/shardformer/layer/linear.py @@ -23,15 +23,10 @@ reduce_input, split_forward_gather_backward, ) -from .parallelmodule import ParallelModule +from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass +__all__ = ['Linear1D_Col', 'Linear1D_Row'] class Linear1D_Col(ParallelModule): @@ -104,8 +99,8 @@ def __init__(self, seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, @@ -146,10 +141,11 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: assert input_.shape[-1] == self.weight.shape[-1], \ diff --git a/colossalai/shardformer/layer/linearconv1d.py b/colossalai/shardformer/layer/linear_conv.py similarity index 92% rename from colossalai/shardformer/layer/linearconv1d.py rename to colossalai/shardformer/layer/linear_conv.py index 4a5cb0707900..b4599f48942d 100644 --- a/colossalai/shardformer/layer/linearconv1d.py +++ b/colossalai/shardformer/layer/linear_conv.py @@ -23,19 +23,15 @@ reduce_input, split_forward_gather_backward, ) -from .parallelmodule import ParallelModule +from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass +__all__ = ['LinearConv1D_Col', 'LinearConv1D_Row'] class LinearConv1D_Col(ParallelModule): r"""Linear layer with column parallelism. + Specially created for HuggingFace's GPT2 model. The linear layer is defined as :math:`Y = XA + b`. A is parallelized along its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface. @@ -104,8 +100,8 @@ def __init__(self, seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, @@ -162,10 +158,11 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: assert input_.shape[-1] == self.weight.shape[-1], \ @@ -192,6 +189,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: class LinearConv1D_Row(ParallelModule): r""" Linear layer with row parallelism + Specially created for HuggingFace's GPT2 model. Args: in_features (int): size of each input sample. @@ -260,8 +258,8 @@ def __init__(self, seed = torch.random.initial_seed() self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) - with self.randomizer.fork_rng(enable_cpu=True): - self.reset_parameters(weight_initializer, bias_initializer) + # init weights + self.reset_parameters(weight_initializer, bias_initializer) @staticmethod def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, @@ -320,20 +318,21 @@ def chunk_weight(self): self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) def reset_parameters(self, weight_initializer, bias_initializer) -> None: - fan_in, fan_out = self.in_features, self.out_features - weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) - - if self.bias is not None: - bias_initializer(self.bias, fan_in=fan_in) - if self.process_group is None: - src_rank = 0 - else: - src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) - - origin_device = self.bias.device - self.bias = self.bias.cuda() - dist.broadcast(self.bias, src=src_rank, group=self.process_group) - self.bias = self.bias.to(origin_device) + with self.randomizer.fork_rng(enable_cpu=True): + fan_in, fan_out = self.in_features, self.out_features + weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) + + if self.bias is not None: + bias_initializer(self.bias, fan_in=fan_in) + if self.process_group is None: + src_rank = 0 + else: + src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) + + origin_device = self.bias.device + self.bias = self.bias.cuda() + dist.broadcast(self.bias, src=src_rank, group=self.process_group) + self.bias = self.bias.to(origin_device) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/loss.py similarity index 98% rename from colossalai/shardformer/layer/dist_crossentropy.py rename to colossalai/shardformer/layer/loss.py index 7840c2f2e5da..38a5395a0f57 100644 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ b/colossalai/shardformer/layer/loss.py @@ -1,10 +1,10 @@ import torch import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F from torch.autograd import Function from torch.distributed import ProcessGroup +__all__ = ['DistCrossEntropy', 'cross_entropy_1d'] + class DistCrossEntropy(Function): r""" diff --git a/colossalai/shardformer/layer/parallelmodule.py b/colossalai/shardformer/layer/parallel_module.py similarity index 78% rename from colossalai/shardformer/layer/parallelmodule.py rename to colossalai/shardformer/layer/parallel_module.py index 3d19bbea7e47..c68cd57786ab 100644 --- a/colossalai/shardformer/layer/parallelmodule.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -7,15 +7,7 @@ import torch.nn as nn from torch.distributed import ProcessGroup -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn import init as init - -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass +__all__ = ['ParallelModule'] class ParallelModule(nn.Module, ABC): diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 175a914a84f9..b5d9cdbd7289 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Type, Union import torch.nn as nn diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 2a204f0defe4..d5e8e01cf154 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -4,7 +4,7 @@ import colossalai.shardformer.layer as col_nn from colossalai.shardformer.layer.dropout import Dropout1D -from ..utils import getattr_, setattr_ +from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index d255325b2084..da9e6b7bd32d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,12 +1,7 @@ -from typing import Type, Union - -import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.layer.dropout import Dropout1D -from ..utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 9e0c8604969c..30433f751088 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,5 +1,3 @@ -import torch -import torch.nn as nn from transformers import T5ForConditionalGeneration from transformers.models.t5.modeling_t5 import ( T5Attention, diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 66934b09b3ac..22f5f1c12d26 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -4,9 +4,9 @@ from colossalai.cluster.process_group_manager import ProcessGroupManager +from .._utils import getattr_, setattr_ from ..policies.autopolicy import get_autopolicy from ..policies.basepolicy import Policy, SubModuleReplacementDescription -from ..utils.utils import getattr_, setattr_ from .shard_config import ShardConfig __all__ = ['ModelSharder', 'shard_model'] diff --git a/colossalai/shardformer/utils/__init__.py b/colossalai/shardformer/utils/__init__.py deleted file mode 100644 index b50e7b2f6d80..000000000000 --- a/colossalai/shardformer/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utils import getattr_, hasattr_, setattr_ diff --git a/tests/test_shardformer/test_module/test_distcrossentropy.py b/tests/test_shardformer/test_layer/test_dist_crossentropy.py similarity index 87% rename from tests/test_shardformer/test_module/test_distcrossentropy.py rename to tests/test_shardformer/test_layer/test_dist_crossentropy.py index 9a19ec57821d..72e6e5cf26ed 100644 --- a/tests/test_shardformer/test_module/test_distcrossentropy.py +++ b/tests/test_shardformer/test_layer/test_dist_crossentropy.py @@ -4,7 +4,7 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy +from colossalai.shardformer.layer import cross_entropy_1d from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) @@ -25,7 +25,7 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index): org_loss = F.cross_entropy(org_pred, org_labels) dist_pred = pred.chunk(world_size, -1)[rank] - dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index) + dist_loss = cross_entropy_1d(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index) assert torch.allclose(org_loss, dist_loss, atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}" diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py index c48c11b36d91..c62d25d94aa4 100644 --- a/tests/test_shardformer/test_layer/test_dropout.py +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -3,7 +3,7 @@ import torch.nn as nn import colossalai -from colossalai.shardformer.layer.dropout import Dropout1D +from colossalai.shardformer.layer import Dropout1D from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index 462349ecb93b..70500008cfff 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -4,7 +4,7 @@ from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer.layers import Embedding1D +from colossalai.shardformer.layer import Embedding1D from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 2a3ce99384cb..00ecc37ce2fa 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -4,7 +4,7 @@ from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index 3df53e8a8458..bee44a2fb109 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -4,7 +4,7 @@ from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer.layers import VocabParallelEmbedding1D +from colossalai.shardformer.layer import VocabParallelEmbedding1D from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn diff --git a/tests/test_shardformer/test_module/test_dropout.py b/tests/test_shardformer/test_module/test_dropout.py deleted file mode 100644 index 4a13eb61c1fc..000000000000 --- a/tests/test_shardformer/test_module/test_dropout.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F - -import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.layer.dropout import Dropout1D -from colossalai.testing import rerun_if_address_is_in_use, spawn - -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) - - -def check_dropout(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') - - # prepare data - input = torch.randn(5, 4).to('cuda') - dropout = Dropout1D(p=0.4).to('cuda') - output_list = [] - # compare the dropout pattern in each device - for i in range(2): - output = dropout(input) - output_list.append(output) - dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)] - torch.distributed.all_gather(dist_output_list, output) - for j in range(world_size): - for k in range(world_size): - if j != k: - mask = torch.eq(dist_output_list[j], 0.0) == torch.eq(dist_output_list[k], 0.0) - assert torch.all( - mask - ) == False, f"The dropout pattern in each device is not unique\n{dist_output_list[j]}\n{dist_output_list[k]}" - # compare the dropout pattern in loacl device - for i in range(len(output_list)): - for j in range(len(output_list)): - if i != j: - mask = torch.eq(output_list[i], 0.0) == torch.eq(output_list[j], 0.0) - assert torch.all( - mask - ) == False, f"The dropout pattern in one device is not unique\n{output_list[i]}\n{output_list[j]}" - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_dropout(): - spawn(check_dropout, 2) - - -if __name__ == '__main__': - test_dropout() diff --git a/tests/test_shardformer/test_module/test_slicer.py b/tests/test_shardformer/test_module/test_slicer.py deleted file mode 100644 index c72a0357573b..000000000000 --- a/tests/test_shardformer/test_module/test_slicer.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytest -import torch -import torch.nn.functional as F - -import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.policies.basepolicy import Col_Layer, Layer, Row_Layer -from colossalai.shardformer.shard.shard_config import ShardConfig -from colossalai.shardformer.shard.slicer import Slicer -from colossalai.testing import rerun_if_address_is_in_use, spawn - -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) - - -def check_slicer(rank, world_size, port, in_feature, out_feature): - disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl') - # initialize slicer - shardconfig = ShardConfig(rank=rank, world_size=world_size) - slicer = Slicer(shardconfig) - # initialize test data - weight = torch.randn(in_feature, out_feature) - bias = torch.randn(out_feature) - policy_layer_cls_list = [Layer, Col_Layer, Row_Layer] - n_cast_list = [None, 2, 3, 4] - # weight and bias - for n_cast in n_cast_list: - sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Layer, n_cast=n_cast) - expected_sliced_weight = weight - expected_sliced_bias = bias - assert torch.equal( - sliced_weight, expected_sliced_weight - ), f"In Layer case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" - assert torch.equal( - sliced_bias, expected_sliced_bias - ), f"In Layer case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" - - sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Col_Layer, n_cast=n_cast) - if (n_cast is None): - expected_sliced_weight = weight.chunk(world_size, dim=0)[rank] - expected_sliced_bias = bias.chunk(world_size)[rank] - else: - chunks = weight.chunk(world_size * n_cast, dim=0) - expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=0) - chunks = bias.chunk(world_size * n_cast, dim=0) - expected_sliced_bias = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)]) - assert torch.equal( - sliced_weight, expected_sliced_weight - ), f"In Col_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" - assert torch.equal( - sliced_bias, expected_sliced_bias - ), f"In Col_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}" - - sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Row_Layer, n_cast=n_cast) - if (n_cast is None): - expected_sliced_weight = weight.chunk(world_size, dim=1)[rank] - expected_sliced_bias = bias - else: - chunks = weight.chunk(world_size * n_cast, dim=1) - expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=1) - expected_sliced_bias = bias - assert torch.equal( - sliced_weight, expected_sliced_weight - ), f"In Row_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" - assert torch.equal( - sliced_bias, expected_sliced_bias - ), f"In Row_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}" - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_slicer(): - args = dict(in_feature=24, out_feature=48) - spawn(check_slicer, nprocs=2, in_feature=args['in_feature'], out_feature=args['out_feature']) - - -if __name__ == '__main__': - test_slicer() From d5d9178a9de92445e79ff4804dcb360ab380699b Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 22 Jun 2023 10:33:06 +0800 Subject: [PATCH 31/49] support kit use for bert/gpt test (#4055) * support kit use for bert test * support kit test for gpt2 --- colossalai/shardformer/policies/autopolicy.py | 20 ++- colossalai/shardformer/policies/bert.py | 23 ++- colossalai/shardformer/policies/gpt2.py | 81 +++++++++- tests/kit/model_zoo/transformers/bert.py | 140 +++++++++++++----- tests/kit/model_zoo/transformers/gpt.py | 69 +++++++-- .../test_model/test_shard_bert.py | 116 +++------------ .../test_model/test_shard_gpt2.py | 116 ++++----------- 7 files changed, 319 insertions(+), 246 deletions(-) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 5e7a285e3285..b1b8c6156f9f 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -25,17 +25,19 @@ class PolicyLocation: _POLICY_LIST = { # BERT "transformers.models.bert.modeling_bert.BertModel": - PolicyLocation(file_name="bert", class_name="BertPolicy"), + PolicyLocation(file_name="bert", class_name="BertModelPolicy"), "transformers.models.bert.modeling_bert.BertForPreTraining": PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"), - "transformers.models.bert.modeling_bert.BertForMaskedLM": - PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"), "transformers.models.bert.modeling_bert.BertLMHeadModel": PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"), - "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": - PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), + "transformers.models.bert.modeling_bert.BertForMaskedLM": + PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"), "transformers.models.bert.modeling_bert.BertForSequenceClassification": PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"), + "transformers.models.bert.modeling_bert.BertForTokenClassification": + PolicyLocation(file_name="bert", class_name="BertForTokenClassificationPolicy"), + "transformers.models.bert.modeling_bert.BertForNextSentencePrediction": + PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"), "transformers.models.bert.modeling_bert.BertForMultipleChoice": PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"), @@ -58,6 +60,14 @@ class PolicyLocation: # GPT2 "transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": + PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel": + PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification": + PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"), + "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": + PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"), } diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index d5e8e01cf154..8649c0dbeaa6 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -131,8 +131,8 @@ def postprocess(self): return self.model -# BertForMaskedLM -class BertForMaskedLMPolicy(BertPolicy): +# BertLMHeadModel +class BertLMHeadModelPolicy(BertPolicy): def __init__(self) -> None: super().__init__() @@ -162,8 +162,8 @@ def postprocess(self): return self.model -# BertLMHeadModel -class BertLMHeadModelPolicy(BertPolicy): +# BertForMaskedLM +class BertForMaskedLMPolicy(BertPolicy): def __init__(self) -> None: super().__init__() @@ -193,15 +193,22 @@ def postprocess(self): return self.model -# BertForNextSentencePrediction -class BertForNextSentencePredictionPolicy(BertPolicy): +# BertForSequenceClassification +class BertForSequenceClassificationPolicy(BertPolicy): def __init__(self) -> None: super().__init__() -# BertForSequenceClassification -class BertForSequenceClassificationPolicy(BertPolicy): +# BertForTokenClassification +class BertForTokenClassificationPolicy(BertPolicy): + + def __init__(self) -> None: + super().__init__() + + +# BertForNextSentencePrediction +class BertForNextSentencePredictionPolicy(BertPolicy): def __init__(self) -> None: super().__init__() diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index da9e6b7bd32d..54ea2f6e3279 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,7 +1,9 @@ -from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model +import torch.nn as nn +from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model import colossalai.shardformer.layer as col_nn +from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -82,7 +84,6 @@ def module_policy(self): } def new_model_class(self): - return self.model def postprocess(self): @@ -94,3 +95,79 @@ class GPT2ModelPolicy(GPT2Policy): def __init__(self) -> None: super().__init__() + + +# GPT2LMHeadModel +class GPT2LMHeadModelPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + addon_module = { + GPT2LMHeadModel: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + return module_policy + + def postprocess(self): + binding_map = {"transformer.wte.weight": "lm_head.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + + +# GPT22DoubleHeadsModel +class GPT2DoubleHeadsModelPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + addon_module = { + GPT2DoubleHeadsModel: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) + return module_policy + + def postprocess(self): + binding_map = {"transformer.wte.weight": "lm_head.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + param = nn.Parameter(param) + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + + +# GPT2ForTokenClassification +class GPT2ForTokenClassificationPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() + + +# GPT2ForSequenceClassification +class GPT2ForSequenceClassificationPolicy(GPT2Policy): + + def __init__(self) -> None: + super().__init__() diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 99135704da70..d2d3de7b7bee 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -6,83 +6,147 @@ # =============================== # Register single-sentence BERT # =============================== -BATCH_SIZE = 2 -SEQ_LENGTH = 16 -def data_gen_fn(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) +# define data gen function +def data_gen(): + # Generated from following code snippet + # + # from transformers import BertTokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + # token_type_ids = tokenized_input['token_type_ids'] + input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64) + token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['labels'] = data['input_ids'].clone() + return data + + +def data_gen_for_pretraining(): + # pretraining data gen + # `next_sentence_label` is the label for next sentence prediction, 0 or 1 + data = data_gen_for_lm() + data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + # `labels` is the label for sequence classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([1], dtype=torch.int64) + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_mcq(): + # multiple choice question data gen + # Generated from following code snippet + # + # tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") + # prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + # choice0 = "It is eaten with a fork and a knife." + # choice1 = "It is eaten while held in the hand." + # data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) + # data = {k: v.unsqueeze(0) for k, v in encoding.items()} + # data['labels'] = torch.tensor([0], dtype=torch.int64) + input_ids = torch.tensor([[[ + 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591, + 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102 + ], + [ + 101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, + 4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096, + 2218, 1999, 1996, 2192, 1012, 102, 0 + ]]]) + token_type_ids = torch.tensor( + [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]]) + attention_mask = torch.tensor( + [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]]) + labels = torch.tensor([0], dtype=torch.int64) + + return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels) + + +# define output transform function output_transform_fn = lambda x: x -config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) +# define loss funciton +loss_fn_for_bert_model = lambda x: x.pooler_output.mean() +loss_fn = lambda x: x.loss + +config = transformers.BertConfig(hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=256, + hidden_dropout_prob=0, + attention_probs_dropout_prob=0) # register the BERT variants model_zoo.register(name='transformers_bert', model_fn=lambda: transformers.BertModel(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bert_model, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_pretraining', model_fn=lambda: transformers.BertForPreTraining(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_pretraining, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_lm_head_model', model_fn=lambda: transformers.BertLMHeadModel(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_masked_lm', model_fn=lambda: transformers.BertForMaskedLM(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_sequence_classification', model_fn=lambda: transformers.BertForSequenceClassification(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_token_classification', model_fn=lambda: transformers.BertForTokenClassification(config), - data_gen_fn=data_gen_fn, + data_gen_fn=data_gen_for_token_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) - - -# =============================== -# Register multi-sentence BERT -# =============================== -def data_gen_for_next_sentence(): - tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") - prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - next_sentence = "The sky is blue due to the shorter wavelength of blue light." - encoding = tokenizer(prompt, next_sentence, return_tensors="pt") - return encoding - - -def data_gen_for_mcq(): - tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") - prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." - choice0 = "It is eaten with a fork and a knife." - choice1 = "It is eaten while held in the hand." - encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) - encoding = {k: v.unsqueeze(0) for k, v in encoding.items()} - return encoding - - -# register the following models model_zoo.register(name='transformers_bert_for_next_sentence', model_fn=lambda: transformers.BertForNextSentencePrediction(config), - data_gen_fn=data_gen_for_next_sentence, + data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_bert_for_mcq', model_fn=lambda: transformers.BertForMultipleChoice(config), data_gen_fn=data_gen_for_mcq, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 5ed4fbe70dc9..c598fa8f48e0 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -11,47 +11,86 @@ def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + # Generated from following code snippet + # + # from transformers import GPT2Tokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) -def seq_classification_data_gen(): - # batch sizes should be 1 if no padding token is defined. - input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64) - return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['labels'] = data['input_ids'].clone() + return data +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data['labels'] = torch.tensor([0], dtype=torch.int64) + return data + + +# define output transform function output_transform_fn = lambda x: x -config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4) +# define loss function +loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean() +loss_fn = lambda x: x.loss + +config = transformers.GPT2Config(n_layer=2, + n_head=4, + vocab_size=50258, + attn_pdrop=0, + embd_pdrop=0, + resid_pdrop=0, + summary_first_dropout=0, + hidden_dropout=0, + problem_type="single_label_classification") # register the following models model_zoo.register(name='transformers_gpt', model_fn=lambda: transformers.GPT2Model(config), data_gen_fn=data_gen, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_gpt2_model, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_lm', model_fn=lambda: transformers.GPT2LMHeadModel(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_double_heads', model_fn=lambda: transformers.GPT2DoubleHeadsModel(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_lm, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_token_classification', model_fn=lambda: transformers.GPT2ForTokenClassification(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_token_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_gpt_for_sequence_classification', model_fn=lambda: transformers.GPT2ForSequenceClassification(config), - data_gen_fn=seq_classification_data_gen, + data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 043ed1a74a27..ad98e3d073d4 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -1,86 +1,30 @@ -import copy -import os - import pytest import torch -from transformers import ( - AutoTokenizer, - BertConfig, - BertForMaskedLM, - BertForNextSentencePrediction, - BertForPreTraining, - BertForSequenceClassification, - BertLMHeadModel, - BertModel, -) import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.testing import rerun_if_address_is_in_use, spawn - -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) -tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - - -def build_model(world_size, model_fn): - config = BertConfig() - config.hidden_dropout_prob = 0 - config.attention_probs_dropout_prob = 0 - - org_model = model_fn(config=config) - org_model_forshard = copy.deepcopy(org_model) - - org_model.to('cuda') - # TODO: no need to transfer to cuda - org_model_forshard.to('cuda') - shard_config = ShardConfig(tensor_parallel_size=world_size,) - shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() - sharded_model = shard_former.shard_model(org_model_forshard).to('cuda') - - return org_model, sharded_model - +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward -def check_forward(org_model, sharded_model): - input = 'Hello, my dog is cute' - tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') - #orgin model - org_model.eval() - org_out = org_model(**tokenized_input) +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output) - #shard model - sharded_model.eval() - shard_out = sharded_model(**tokenized_input) - - assert torch.allclose( - org_out[0], shard_out[0], - atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" - - -def check_backward(org_model, sharded_model): - # prepare input - input = 'Hello, my dog is cute' - tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') - labels = tokenized_input['input_ids'].clone() - labels[labels == tokenizer.pad_token_id] = -100 - tokenized_input['labels'] = labels - - #orgin model - org_model.train() - org_out = org_model(**tokenized_input) - org_loss = org_out.loss + # do backward org_loss.backward() - org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad - - #shard model - sharded_model.train() - shard_out = sharded_model(**tokenized_input) - shard_loss = shard_out.loss shard_loss.backward() - shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad + + # check grad equality + if org_model.__class__.__name__ == 'BertModel': + org_grad = org_model.encoder.layer[0].attention.self.query.weight.grad + shard_grad = sharded_model.encoder.layer[0].attention.self.query.weight.grad + else: + org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad + shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) @@ -89,36 +33,24 @@ def check_backward(org_model, sharded_model): assert torch.allclose(org_loss, shard_loss, atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" def check_bert(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - forward_list = [ - BertForMaskedLM, - BertForPreTraining, - BertLMHeadModel, + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - # TODO: do not work yet - # BertModel, - # BertForSequenceClassification - # BertForNextSentencePrediction, - ] - backward_lsit = [BertForMaskedLM, BertLMHeadModel] - - for model_fn in forward_list: + sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(world_size, model_fn) - check_forward(org_model, sharded_model) - - if model_fn in backward_lsit: - check_backward(org_model, sharded_model) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() + torch.cuda.empty_cache() @pytest.mark.dist @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_bert(): spawn(check_bert, 2) diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 2f679b83f99b..0c07f44401c7 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -1,117 +1,61 @@ -import copy -import os - import pytest import torch -from transformers import AutoTokenizer, GPT2Config, GPT2Model import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.testing import rerun_if_address_is_in_use, spawn - -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) -tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - - -def build_model(world_size, model_fn): - config = GPT2Config() - config.attn_pdrop = 0 - config.embd_pdrop = 0 - config.resid_pdrop = 0 - config.summary_first_dropout - - org_model = model_fn(config=config) - org_model_forshard = copy.deepcopy(org_model) - - org_model.to('cuda') - # TODO: no need to transfer to cuda - org_model_forshard.to('cuda') - shard_config = ShardConfig(tensor_parallel_size=world_size,) - shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() - sharded_model = shard_former.shard_model(org_model_forshard).to('cuda') - - return org_model, sharded_model - +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward -def check_forward(org_model, sharded_model): - input = 'Hello, my dog is cute' - tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') - #orgin model - org_model.eval() - org_out = org_model(**tokenized_input) +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) - #shard model - sharded_model.eval() - shard_out = sharded_model(**tokenized_input) - - assert torch.allclose( - org_out[0], shard_out[0], - atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" - - -def check_backward(org_model, sharded_model): - # prepare input - input = 'Hello, my dog is cute' - tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') - labels = tokenized_input['input_ids'].clone() - labels[labels == tokenizer.pad_token_id] = -100 - # tokenized_input['labels'] = labels - - #orgin model - org_model.train() - org_out = org_model(**tokenized_input) - org_loss = org_out.loss + # do backward org_loss.backward() - org_grad = org_model.h[0].attn.c_attn.weight.grad - - #shard model - sharded_model.train() - shard_out = sharded_model(**tokenized_input) - shard_loss = shard_out.loss shard_loss.backward() - shard_grad = sharded_model.h[0].attn.c_attn.weight.grad + + # check grad equality + if org_model.__class__.__name__ == 'GPT2Model': + org_grad = org_model.h[0].attn.c_attn.weight.grad + shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous() + else: + org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad + shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous() shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + all_shard_grad = torch.cat(shard_grad_list, dim=1) assert torch.allclose(org_loss, shard_loss, atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -def check_bert(rank, world_size, port): +def check_gpt2(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - forward_list = [ - GPT2Model, + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - # TODO: do not work yet - # BertModel, - # BertForSequenceClassification - # BertForNextSentencePrediction, - ] - backward_lsit = [] - - for model_fn in forward_list: + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + print(name) + # if name == 'transformers_gpt': + # continue org_model, sharded_model = build_model(world_size, model_fn) - check_forward(org_model, sharded_model) - - if model_fn in backward_lsit: - check_backward(org_model, sharded_model) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() + torch.cuda.empty_cache() @pytest.mark.dist @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_gpt2(): - spawn(check_bert, 2) + spawn(check_gpt2, 2) if __name__ == "__main__": From 9436f739fedd2a26457aeca6301ae4e0097f1226 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 22 Jun 2023 11:42:11 +0800 Subject: [PATCH 32/49] [shardformer] support module saving and loading (#4062) * [shardformer] support module saving and loading * polish code --- colossalai/checkpoint_io/utils.py | 4 +- colossalai/lazy/lazy_init.py | 7 +- colossalai/shardformer/layer/embedding.py | 25 +- colossalai/shardformer/layer/linear.py | 30 +- colossalai/shardformer/layer/linear_conv.py | 2 - .../shardformer/layer/parallel_module.py | 142 ++++++++++ colossalai/tensor/d_tensor/__init__.py | 24 ++ colossalai/tensor/d_tensor/api.py | 263 ++++++++++++++++-- colossalai/tensor/d_tensor/layout.py | 3 +- .../tensor/d_tensor/layout_converter.py | 13 +- colossalai/tensor/d_tensor/utils.py | 2 +- test.py | 1 + tests/test_lazy/lazy_init_utils.py | 9 +- .../test_layer/test_embedding.py | 4 + .../test_layer/test_linear_1d.py | 12 +- .../test_vocab_parallel_embedding_1d.py | 6 +- .../test_dtensor/test_comm_spec.py | 3 - .../test_tensor/test_dtensor/test_dtensor.py | 43 ++- .../test_dtensor/test_layout_converter.py | 2 +- 19 files changed, 493 insertions(+), 102 deletions(-) create mode 100644 test.py diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 3dada00cd9b5..68981dff0d0a 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -10,7 +10,7 @@ import torch.nn as nn from torch.optim import Optimizer -from colossalai.tensor.d_tensor.d_tensor import DTensor +from colossalai.tensor.d_tensor import is_distributed_tensor SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) for key, weight in state_dict.items(): ret_block = None ret_block_size = 0 - if type(weight) != DTensor: + if is_distributed_tensor(weight): weight_size = calculate_tensor_size(weight) # If this weight is going to tip up over the maximal size, we split. diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 76f550dc4392..1e45eced5f34 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -8,8 +8,9 @@ from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor -from colossalai.tensor.d_tensor.d_tensor import DTensor -from colossalai.tensor.d_tensor.layout import Layout +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.d_tensor import distribute_tensor +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html _NORMAL_FACTORY = [ @@ -183,7 +184,7 @@ def distribute(self, layout: Layout) -> torch.Tensor: """ target = self._materialize_data() self.clean() - local_tensor = DTensor(target, layout).local_tensor + local_tensor = distribute_tensor(target, device_mesh, sharding_spec) return _convert_cls(self, local_tensor) def clean(self) -> None: diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 8b9fb03ec798..23601a04a27b 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -13,8 +13,7 @@ from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise -from colossalai.utils.cuda import get_current_device +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param from ._operation import gather_forward_split_backward, reduce_input from .parallel_module import ParallelModule @@ -69,18 +68,17 @@ def __init__(self, self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.process_group = process_group - self.num_partitions = dist.get_world_size(process_group) - self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs self.gather_output = gather_output - if device is None: - device = get_current_device() - - self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) + # Parameters. + factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) + sharded_weight = shard_colwise(weight, process_group) + self.weight = sharded_tensor_to_param(sharded_weight) # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -194,7 +192,7 @@ def __init__(self, **kwargs): super().__init__() self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim + self.embedding_dim = embedding_dim self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs @@ -208,8 +206,11 @@ def __init__(self, self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype)) + # parameter + factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) + sharded_weight = shard_rowwise(weight, process_group) + self.weight = sharded_tensor_to_param(sharded_weight) # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -252,7 +253,7 @@ def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, def reset_parameters(self, weight_initializer) -> None: with self.randomizer.fork_rng(enable_cpu=True): - fan_in, fan_out = self.num_embeddings, self.embed_dim + fan_in, fan_out = self.num_embeddings, self.embedding_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index b87981c6db42..912be26b99ba 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -14,7 +14,7 @@ from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise +from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param from colossalai.utils.cuda import get_current_device from ._operation import ( @@ -76,22 +76,21 @@ def __init__(self, self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - self.out_features_per_partition = divide(out_features, self.num_partitions) - # Parameters. - # Initialize weight. - if device is None: - device = get_current_device() factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) + sharded_weight = shard_rowwise(weight, self.process_group) + self.weight = sharded_tensor_to_param(sharded_weight) if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + bias = torch.empty(self.out_features, **factory_kwargs) + sharded_bias = shard_colwise(bias, self.process_group) + self.bias = sharded_tensor_to_param(sharded_bias) else: self.bias = None @@ -128,7 +127,6 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis *args, **kwargs) - # TODO: copy the sharded weights with torch.no_grad(): # the weigh to the linear layer is a transpose # thus shard on row is equal to shard on column @@ -137,7 +135,6 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis if bias: sharded_bias = shard_colwise(module.bias.data, process_group) linear_1d.bias.copy_(sharded_bias) - return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: @@ -212,21 +209,20 @@ def __init__(self, self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, self.num_partitions) - # Parameters. # Initialize weight. if device is None: device = get_current_device() factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) + sharded_weight = shard_colwise(weight, self.process_group) + self.weight = sharded_tensor_to_param(sharded_weight) if self.stream_chunk_num > 1: # TODO() work for inference only @@ -340,3 +336,5 @@ def forward(self, input_: Tensor) -> Tensor: return output else: return output, self.bias + return output, self.bias + return output, self.bias diff --git a/colossalai/shardformer/layer/linear_conv.py b/colossalai/shardformer/layer/linear_conv.py index b4599f48942d..2adfc182895e 100644 --- a/colossalai/shardformer/layer/linear_conv.py +++ b/colossalai/shardformer/layer/linear_conv.py @@ -31,7 +31,6 @@ class LinearConv1D_Col(ParallelModule): r"""Linear layer with column parallelism. - Specially created for HuggingFace's GPT2 model. The linear layer is defined as :math:`Y = XA + b`. A is parallelized along its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface. @@ -189,7 +188,6 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: class LinearConv1D_Row(ParallelModule): r""" Linear layer with row parallelism - Specially created for HuggingFace's GPT2 model. Args: in_features (int): size of each input sample. diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index c68cd57786ab..5edcb9dde748 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -1,11 +1,23 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import itertools from abc import ABC, abstractmethod from typing import List, Union +import torch import torch.nn as nn from torch.distributed import ProcessGroup +from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module + +from colossalai.tensor.d_tensor import ( + distribute_tensor, + get_device_mesh, + get_sharding_spec, + is_distributed_tensor, + sharded_tensor_to_param, + to_global, +) __all__ = ['ParallelModule'] @@ -25,3 +37,133 @@ def from_native_module(module: nn.Module, in the ith axis of the device mesh. Defaults to None, which means the global process group. """ pass + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + param_ = param if keep_vars else param.detach() + + if is_distributed_tensor(param_): + destination[prefix + name] = to_global(param_) + else: + destination[prefix + name] = param_ + + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append('While copying the parameter named "{}", ' + 'expected torch.Tensor or Tensor-like object from checkpoint but ' + 'received {}'.format(key, type(input_param))) + continue + + if is_distributed_tensor(param): + # shard the input param + device_mesh = get_device_mesh(param) + sharding_spec = get_sharding_spec(param) + sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec) + input_param = sharded_tensor_to_param(sharded_tensor) + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.'.format(key, input_param.shape, param.shape)) + continue + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.'.format(key, param.size(), input_param.size(), + ex.args)) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index e69de29bb2d1..52eae0e14877 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -0,0 +1,24 @@ +from .api import ( + compute_global_numel, + distribute_tensor, + get_device_mesh, + get_global_shape, + get_layout, + get_sharding_spec, + is_distributed_tensor, + is_sharded, + redistribute, + shard_colwise, + shard_rowwise, + sharded_tensor_to_param, + to_global, +) +from .layout import Layout +from .sharding_spec import ShardingSpec + +__all__ = [ + 'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise', + 'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh', + 'redistribute', 'get_layout' + 'Layout', 'ShardingSpec' +] diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index b58edadfef20..a38e5e6b7184 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -1,3 +1,6 @@ +import copy +import operator +from functools import reduce from typing import Union import torch @@ -6,13 +9,165 @@ from colossalai.device.device_mesh import DeviceMesh -from .d_tensor import DTensor +from .layout import Layout +from .layout_converter import LayoutConverter from .sharding_spec import ShardingSpec +layout_converter = LayoutConverter() -def shard_rowwise(tensor: torch.Tensor, - group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None, - inplace: bool = False) -> DTensor: + +def is_distributed_tensor(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a distributed tensor. + """ + return hasattr(tensor, "dist_layout") + + +def is_sharded(dtensor: torch.Tensor) -> bool: + """ + Check if a tensor is sharded. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: True if the tensor is sharded, False otherwise. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return list(dtensor.shape) == list(dtensor.dist_layout.global_shape) + + +def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + dtensor._old_detach = dtensor.detach + dtensor._old_clone = dtensor.clone + + def new_detach(self): + t_ = self._old_detach() + t_.dist_layout = copy.deepcopy(self.dist_layout) + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._old_clone(*args, **kwargs) + t_.dist_layout = copy.deepcopy(self.dist_layout) + return t_ + + # bind the new methods to the tensor + dtensor.detach = new_detach.__get__(dtensor) + dtensor.clone = new_clone.__get__(dtensor) + return dtensor + + +def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: + ''' + Construct the default sharding specification for the tensor. + + Args: + tensor (`torch.Tensor`): the tensor to be sharded. + + Returns: + A `ShardingSpec` object without any sharding specified. + ''' + return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={}) + + +def _apply_layout(tensor, layout): + ''' + Apply the layout to the local tensor during initializing process. + ''' + # layout converter requires a source and target laytout + # we construct the source layer for an unsharded tensor + # and use self.dist_layer as the targer layout for the sharded tensor + source_spec = _construct_default_sharding_spec(tensor) + source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape) + sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout) + return sharded_tensor + + +def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: + """ + Convert the given tensor to a distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be converted. + device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices. + sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded. + + Returns: + torch.Tensor: The distributed tensor. + """ + assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.' + dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape) + + # shard tensor + sharded_tensor = _apply_layout(tensor, dist_layout) + + # hack some tensor methods + _hijack_detach_and_clone(sharded_tensor) + + return sharded_tensor + + +def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: + ''' + Convert the layout of the tensor from source_spec to target_spec. + This will update the `local_tensor` and `dist_layout` in place. + + Args: + dtensor (torch.Tensor): the distributed tensor to be converted. + device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices. + target_layout (Layout): the target layout specification. + ''' + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + global_shape = get_global_shape(dtensor) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) + resharded_tensor = layout_converter.apply(tensor=dtensor, + source_layout=dtensor.dist_layout, + target_layout=target_layout) + return resharded_tensor + + +def to_global(dtensor: torch.Tensor) -> torch.Tensor: + """ + Convert a distributed tensor to the global tensor with the given layout. + This function returns a native `torch.Tensor` object. + + Args: + dtensor (torch.Tensor): the distributed tensor to be converted. + + Returns: + torch.Tensor: the global tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + layout_converter = LayoutConverter() + + global_sharding_spec = ShardingSpec(dtensor.dim(), {}) + device_mesh = get_device_mesh(dtensor) + global_shape = get_global_shape(dtensor) + global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape) + + global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout) + return global_tensor + + +def shard_rowwise( + tensor: torch.Tensor, + group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None, +) -> torch.Tensor: """ Shard the first dim of the given tensor. @@ -24,7 +179,7 @@ def shard_rowwise(tensor: torch.Tensor, inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False. Returns: - DTensor: The sharded tensor. + torch.Tensor: The sharded tensor. """ # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group if group_or_device_mesh is None: @@ -35,17 +190,13 @@ def shard_rowwise(tensor: torch.Tensor, else: assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' device_mesh = group_or_device_mesh - sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]}) - if not inplace: - tensor = tensor.detach().clone() + sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]}) - return DTensor(tensor, device_mesh, sharding_spec) + return distribute_tensor(tensor, device_mesh, sharding_spec) -def shard_colwise(tensor: torch.Tensor, - group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None, - inplace: bool = False) -> DTensor: +def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor: """ Shard the first dim of the given tensor. @@ -57,7 +208,7 @@ def shard_colwise(tensor: torch.Tensor, inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False. Returns: - DTensor: The sharded tensor. + torch.Tensor: The sharded tensor. """ # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group if group_or_device_mesh is None: @@ -70,7 +221,87 @@ def shard_colwise(tensor: torch.Tensor, device_mesh = group_or_device_mesh sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]}) - if not inplace: - tensor = tensor.detach().clone() + return distribute_tensor(tensor, device_mesh, sharding_spec) + + +def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) + + # make it distributed as well + param.dist_layout = dtensor.dist_layout + _hijack_detach_and_clone(param) + + return param + - return DTensor(tensor, device_mesh, sharding_spec) +def compute_global_numel(dtensor: torch.Tensor) -> int: + """ + Compute the global number of elements in the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + int: The global number of elements in the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + numel = reduce(operator.mul, dtensor.dist_layout.global_shape) + return numel + + +def get_layout(dtensor: torch.Tensor) -> Layout: + """ + Get the layout of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + Layout: The layout of the distributed tensor. + + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout + + +def get_global_shape(dtensor: torch.Tensor) -> torch.Size: + """ + Get the global shape of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + torch.Size: The global shape of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout.global_shape + + +def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh: + """ + Get the device mesh of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + DeviceMesh: The device mesh of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout.device_mesh + + +def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec: + """ + Get the sharding spec of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + ShardingSpec: The sharding spec of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout.sharding_spec diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index f15956ea3d52..4185b85860e3 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -1,12 +1,11 @@ import operator -from dataclasses import dataclass from functools import reduce import torch from colossalai.device.device_mesh import DeviceMesh -from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError +from .misc import DuplicatedShardingDimensionError, ShardingNotDivisibleError from .sharding_spec import ShardingSpec diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index abc70e19a126..14f9c4561622 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -28,18 +28,6 @@ class LayoutConverterOptions: pass -def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor: - layout_converter = LayoutConverter() - global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {}) - global_layout = Layout(device_mesh=layout.device_mesh, - device_type=layout.device_type, - sharding_spec=global_sharding_spec, - entire_shape=layout.entire_shape) - with torch.no_grad(): - global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout) - return global_tensor - - def set_layout_converting_options(options: LayoutConverterOptions): """ Configure the shape consistency manager via function call. @@ -553,4 +541,5 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo _, comm_action_sequence = self.layout_converting(source_layout, target_layout) for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) + tensor.dist_layout = target_layout return tensor diff --git a/colossalai/tensor/d_tensor/utils.py b/colossalai/tensor/d_tensor/utils.py index 644bb6306b42..fc22b990d879 100644 --- a/colossalai/tensor/d_tensor/utils.py +++ b/colossalai/tensor/d_tensor/utils.py @@ -29,7 +29,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals # the comm size for all gather is the size of the gathered tensor gather_dim = comm_spec.gather_dim all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1] - all_gather_size = device_mesh.mesh_shape[all_gather_axis] + all_gather_size = device_mesh.shape[all_gather_axis] comm_size_for_all_gather = comm_size * all_gather_size forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis) # give a tiny cost to shard diff --git a/test.py b/test.py new file mode 100644 index 000000000000..f283e21a1ebd --- /dev/null +++ b/test.py @@ -0,0 +1 @@ +from colossalai.tensor.d_tensor.api import to_distributed_tensor diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 2dd8d1ca3216..3879363bcd1b 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -7,7 +7,8 @@ from packaging import version from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor -from colossalai.tensor.d_tensor.layout_converter import to_global +from colossalai.tensor.d_tensor import to_global +from colossalai.tensor.d_tensor.layout import Layout from tests.kit.model_zoo.registry import ModelAttribute SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0') @@ -91,6 +92,8 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn. assert n1 == n2 t1 = t1.cuda() t2 = t2.cuda() - if n2 in layout_dict: - t2 = to_global(t2, layout_dict[n2]) + if n2 in sharding_spec_dict: + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape) + t2.dist_layout = layout + t2 = to_global(t2) assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index 70500008cfff..8a6aa42a42f2 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -14,6 +14,10 @@ def check_embedding_1d(): assert embedding_1d.weight.shape == torch.Size([32, 64]) + # ensure state dict is reversibly loadable + embedding.load_state_dict(embedding_1d.state_dict()) + embedding_1d.load_state_dict(embedding.state_dict()) + # check computation correctness x = torch.randint(low=0, high=32, size=(4, 32)).cuda() out = embedding(x) diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 00ecc37ce2fa..a2b8bf22c0b2 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -5,6 +5,7 @@ import colossalai from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -12,9 +13,18 @@ def check_linear_1d_col(): linear = nn.Linear(32, 128).cuda() linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) + # ensure that the parameters are distributed + assert is_distributed_tensor(linear_col.weight) + assert is_distributed_tensor(linear_col.bias) + + # ensure the shape is correct assert linear_col.weight.shape == torch.Size([64, 32]) assert linear_col.bias.shape == torch.Size([64]) + # ensure state dict is reversibly loadable + linear.load_state_dict(linear_col.state_dict()) + linear_col.load_state_dict(linear.state_dict()) + # check computation correctness x = torch.rand(4, 32).cuda() out = linear(x) @@ -55,7 +65,7 @@ def check_linear_1d_row(): def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') check_linear_1d_col() - check_linear_1d_row() + # check_linear_1d_row() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index bee44a2fb109..8991d9b304f5 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -14,7 +14,11 @@ def check_vocab_embedding_1d(): assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) assert dist_embedding_1d.num_embeddings == 64 - assert dist_embedding_1d.embed_dim == 32 + assert dist_embedding_1d.embedding_dim == 32 + + # ensure state dict is reversibly loadable + embedding.load_state_dict(dist_embedding_1d.state_dict()) + dist_embedding_1d.load_state_dict(embedding.state_dict()) # check embedding correctness x = torch.randint(0, 128, (4, 32)).to('cuda') diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index d1f5b9299397..958eabb65fac 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -1,14 +1,11 @@ import pytest import torch -import torch.distributed as dist -from torch.distributed import ReduceOp from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 3ca369acbf87..8350fb3e7fe6 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -3,9 +3,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor -from colossalai.tensor.d_tensor.layout import Layout -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec +from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -31,22 +29,18 @@ def check_dtensor(rank, world_size, port): device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]}) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - d_tensor = DTensor(original_tensor, layout) + d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec) - assert d_tensor.entire_shape == original_tensor.shape - assert d_tensor.data_type == original_tensor.dtype + assert get_global_shape(d_tensor) == original_tensor.shape + assert d_tensor.dtype == original_tensor.dtype if rank in (0, 1): - assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2)) + assert d_tensor.equal(original_tensor.narrow(0, 0, 2)) elif rank in (2, 3): - assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2)) + assert d_tensor.equal(original_tensor.narrow(0, 2, 2)) else: raise ValueError(f'rank {rank} is not in the device mesh') - assert d_tensor.to_global().equal(original_tensor) + assert to_global(d_tensor).equal(original_tensor) output = test_model(d_tensor) if rank in (0, 1): @@ -57,34 +51,29 @@ def check_dtensor(rank, world_size, port): raise ValueError(f'rank {rank} is not in the device mesh') new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]}) - new_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=new_sharding_spec, - entire_shape=original_tensor.shape) - - d_tensor.layout_convert(new_layout) + d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec) if rank == 0: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 0, 1)) elif rank == 1: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 1, 1)) elif rank == 2: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 2, 1)) elif rank == 3: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 3, 1)) else: raise ValueError(f'rank {rank} is not in the device mesh') dtensor_from_local = distribute_tensor(original_tensor, new_layout) if rank == 0: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1)) elif rank == 1: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 1, 1)) elif rank == 2: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 2, 1)) elif rank == 3: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1)) else: raise ValueError(f'rank {rank} is not in the device mesh') diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 5c3da5f2b9ff..d9dff8af933d 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -9,7 +9,7 @@ from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.layout_converter import LayoutConverter -from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn entire_shape = torch.Size((64, 32, 16)) From 8108c35c26c852954c610f7945eff20745be4e1a Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Thu, 22 Jun 2023 14:40:37 +0800 Subject: [PATCH 33/49] [shardformer] add linearconv1d test (#4067) * add linearconv1d test * add linearconv1d test --- colossalai/shardformer/layer/linear_conv.py | 36 +++--- colossalai/shardformer/policies/gpt2.py | 10 +- .../test_layer/test_linearconv_1d.py | 107 ++++++++++++++++++ .../test_model/test_shard_gpt2.py | 3 - 4 files changed, 122 insertions(+), 34 deletions(-) create mode 100644 tests/test_shardformer/test_layer/test_linearconv_1d.py diff --git a/colossalai/shardformer/layer/linear_conv.py b/colossalai/shardformer/layer/linear_conv.py index 2adfc182895e..2d1dacf2cd39 100644 --- a/colossalai/shardformer/layer/linear_conv.py +++ b/colossalai/shardformer/layer/linear_conv.py @@ -103,10 +103,15 @@ def __init__(self, self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs) -> ParallelModule: r""" Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. + + Args: + module (`nn.Linear`): The module to be converted. + process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication. + n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight. """ # get the attributes in_features = module.weight.shape[0] @@ -135,20 +140,20 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis # first rearange the order of weight and bias world_size = dist.get_world_size(group=process_group) - order = torch.arange(world_size * n_cast) + order = torch.arange(world_size * n_fused) new_order = [] for i in range(world_size): new_order.append(order[i::world_size]) new_order = torch.cat(new_order) - weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=1) + weight_chunks = torch.chunk(module.weight.data, world_size * n_fused, dim=1) rearanged_weight_chunks = [weight_chunks[i] for i in new_order] rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1) sharded_weight = shard_colwise(rearanged_weight, process_group) linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) if bias: - bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) + bias_chunks = torch.chunk(module.bias.data, world_size * n_fused, dim=0) rearanged_bias_chunks = [bias_chunks[i] for i in new_order] rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) sharded_bias = shard_colwise(rearanged_bias, process_group) @@ -260,8 +265,8 @@ def __init__(self, self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int, - *args, **kwargs) -> ParallelModule: + def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: r""" Convert a native PyTorch linear layer to a parallelized linear layer. """ @@ -289,26 +294,11 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis with torch.no_grad(): # the weigh to the linear layer is a transpose # thus shard on col is equal to shard on row - - # first rearange the order of weight and bias - world_size = dist.get_world_size(group=process_group) - order = torch.arange(world_size * n_cast) - new_order = [] - for i in range(world_size): - new_order.append(order[i::world_size]) - new_order = torch.cat(new_order) - - weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=0) - rearanged_weight_chunks = [weight_chunks[i] for i in new_order] - rearanged_weight = torch.cat(rearanged_weight_chunks, dim=0) - sharded_weight = shard_rowwise(rearanged_weight, process_group) + sharded_weight = shard_rowwise(module.weight.data, process_group) linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) if bias: - bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) - rearanged_bias_chunks = [bias_chunks[i] for i in new_order] - rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) - linear_1d.bias.copy_(rearanged_bias.contiguous()) + linear_1d.bias.copy_(module.bias.data) return linear_1d diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 54ea2f6e3279..9d5d7d36aea3 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -44,29 +44,23 @@ def module_policy(self): suffix="attn.c_attn", target_module=col_nn.LinearConv1D_Col, kwargs={ - "n_cast": 3, + "n_fused": 3, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", target_module=col_nn.LinearConv1D_Row, - kwargs={ - "n_cast": 1, - }, ), SubModuleReplacementDescription( suffix="mlp.c_fc", target_module=col_nn.LinearConv1D_Col, kwargs={ - "n_cast": 1, + "n_fused": 1, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.LinearConv1D_Row, - kwargs={ - "n_cast": 1, - }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", diff --git a/tests/test_shardformer/test_layer/test_linearconv_1d.py b/tests/test_shardformer/test_layer/test_linearconv_1d.py new file mode 100644 index 000000000000..e0c97178d901 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_linearconv_1d.py @@ -0,0 +1,107 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +# This code is copied from https://github.com/huggingface/transformers +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + self.weight = nn.Parameter(torch.empty(nx, nf)) + self.bias = nn.Parameter(torch.zeros(nf)) + nn.init.normal_(self.weight, std=0.02) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def rearrange(tensor: torch.Tensor, dim: int): + tensor = tensor.clone() + world_size = 2 + order = torch.arange(world_size * 3) + new_order = [] + for i in range(world_size): + new_order.append(order[i::world_size]) + new_order = torch.cat(new_order) + + tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim) + rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order] + rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim) + return rearanged_tensor + + +def check_linear_conv_1d_col(): + linear = Conv1D(192, 48).cuda() + linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3) + + assert linear_conv_col.weight.shape == torch.Size([96, 48]) + assert linear_conv_col.bias.shape == torch.Size([96]) + + # check computation correctness + x = torch.rand(4, 48).cuda() + out = linear(x) + gather_out = linear_conv_col(x) + assert_close(rearrange(out, 1), gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] + assert_close(target_grad.transpose(0, 1).contiguous(), linear_conv_col.weight.grad) + + +def check_linear_1d_row(): + linear = Conv1D(192, 48).cuda() + linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + + assert linear_row.weight.shape == torch.Size([192, 24]) + assert linear_row.bias.shape == torch.Size([192]) + + # check computation correctness + x = torch.rand(4, 48).cuda() + out = linear(x) + gather_out = linear_row(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + rank = dist.get_rank() + target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] + assert_close(target_grad, linear_row.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_linear_conv_1d_col() + + +@rerun_if_address_is_in_use() +def test_linearconv(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_linearconv() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 0c07f44401c7..9aa02ec34d17 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -42,9 +42,6 @@ def check_gpt2(rank, world_size, port): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - print(name) - # if name == 'transformers_gpt': - # continue org_model, sharded_model = build_model(world_size, model_fn) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) From a484c71a74f3a50343fd4d01bc45f2813e58793b Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 23 Jun 2023 16:07:09 +0800 Subject: [PATCH 34/49] [shardformer] supported fused qkv checkpoint (#4073) --- colossalai/shardformer/layer/_operation.py | 86 +++++++++- colossalai/shardformer/layer/embedding.py | 4 +- colossalai/shardformer/layer/linear.py | 16 +- colossalai/shardformer/layer/linear_conv.py | 162 ++++++++++++------ .../shardformer/layer/parallel_module.py | 8 +- colossalai/tensor/d_tensor/__init__.py | 8 +- colossalai/tensor/d_tensor/api.py | 127 ++++++++++++++ .../test_layer/test_linear_1d.py | 64 ++++++- .../test_layer/test_linearconv_1d.py | 20 ++- .../test_model/test_shard_gpt2.py | 13 +- 10 files changed, 420 insertions(+), 88 deletions(-) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 280d5526342b..7e97bee01b33 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist +import torch.nn.functional as F try: import fused_mix_prec_layer_norm_cuda @@ -46,7 +47,7 @@ def backward(ctx, grad_output): return grad_input, grad_weight, grad_bias, None, None -class LinearWithAsyncCommunication(torch.autograd.Function): +class MatmulWithAsyncCommunication(torch.autograd.Function): """ Linear layer execution with asynchronous communication in backprop. """ @@ -58,11 +59,59 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce - output = torch.matmul(input_, weight.t()) + output = torch.matmul(input_, weight) + if bias is not None: output = output + bias return output + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + use_bias = ctx.use_bias + + total_input = input + grad_input = grad_output.matmul(weight.T) + grad_output = grad_output.contiguous() + # Convert the tensor shapes to 2D for execution compatibility + if len(grad_output.shape) > 2: + grad_output = grad_output.view(-1, grad_output.shape[-1]) + total_input = total_input.view(-1, total_input.shape[-1]) + + if ctx.async_grad_allreduce: + # Asynchronous all-reduce + handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True) + # Delay the start of weight gradient computation shortly (3us) to have + # all-reduce scheduled first and have GPU resources allocated + _ = torch.empty(1, device=grad_output.device) + 1 + + grad_weight = total_input.t().matmul(grad_output) + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if ctx.async_grad_allreduce: + handle.wait() + + return grad_input, grad_weight, grad_bias, None, None, None + + +class LinearWithAsyncCommunication(torch.autograd.Function): + """ + Linear layer execution with asynchronous communication in backprop. + """ + + @staticmethod + def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): + ctx.save_for_backward(input_, weight) + ctx.use_bias = bias is not None + ctx.process_group = process_group + ctx.async_grad_allreduce = async_grad_allreduce + + if bias is not None: + output = F.linear(input_, weight, bias) + else: + output = F.linear(input_, weight) + return output + @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors @@ -114,7 +163,7 @@ def backward(ctx, grad_output): return _gather(grad_output, ctx.dim, ctx.process_group), None, None -class _ReduceInput(torch.autograd.Function): +class _ReduceForward(torch.autograd.Function): """ All-reduce the input from the model parallel region. @@ -132,6 +181,25 @@ def backward(ctx, grad_output): return grad_output, None +class _ReduceBackward(torch.autograd.Function): + """ + All-reduce the input from the model parallel region. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + """ + + @staticmethod + def forward(ctx, input_, process_group): + ctx.process_group = process_group + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output, ctx.process_group), None + + def _reduce(input_, process_group): # skip if only one rank involved if dist.get_world_size(process_group) == 1: @@ -198,6 +266,10 @@ def backward(ctx, grad_output): return _split(grad_output, ctx.dim, ctx.process_group), None, None +def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): + return MatmulWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) + + def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce): return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce) @@ -210,5 +282,9 @@ def split_forward_gather_backward(input_, dim, process_group): return _SplitForwardGatherBackward.apply(input_, dim, process_group) -def reduce_input(input_, process_group): - return _ReduceInput.apply(input_, process_group) +def reduce_forward(input_, process_group): + return _ReduceForward.apply(input_, process_group) + + +def reduce_backward(input_, process_group): + return _ReduceBackward.apply(input_, process_group) diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 23601a04a27b..db39a457b7fd 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -15,7 +15,7 @@ from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param -from ._operation import gather_forward_split_backward, reduce_input +from ._operation import gather_forward_split_backward, reduce_forward from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset @@ -276,5 +276,5 @@ def forward(self, input_: Tensor) -> Tensor: # Mask the output embedding. output_parallel[input_mask, :] = 0. # Reduce across all the model parallel GPUs. - output = reduce_input(output_parallel, self.process_group) + output = reduce_forward(output_parallel, self.process_group) return output diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index 912be26b99ba..d952d5eecbee 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -15,12 +15,11 @@ from colossalai.nn import init as init from colossalai.nn.layer.utils import divide from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param -from colossalai.utils.cuda import get_current_device from ._operation import ( gather_forward_split_backward, linear_with_async_comm, - reduce_input, + reduce_forward, split_forward_gather_backward, ) from .parallel_module import ParallelModule @@ -148,9 +147,10 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: assert input_.shape[-1] == self.weight.shape[-1], \ 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) + # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) input_parallel = input_ + # Matrix multiply. bias = self.bias if not self.skip_bias_add else None output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) @@ -209,17 +209,14 @@ def __init__(self, self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group + self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') # Parameters. # Initialize weight. - if device is None: - device = get_current_device() - factory_kwargs = {'device': device, 'dtype': dtype} - weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) sharded_weight = shard_colwise(weight, self.process_group) self.weight = sharded_tensor_to_param(sharded_weight) @@ -327,8 +324,7 @@ def forward(self, input_: Tensor) -> Tensor: output = torch.cat(output_parallel_list, dim=-1) else: output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, self.process_group) + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: @@ -336,5 +332,3 @@ def forward(self, input_: Tensor) -> Tensor: return output else: return output, self.bias - return output, self.bias - return output, self.bias diff --git a/colossalai/shardformer/layer/linear_conv.py b/colossalai/shardformer/layer/linear_conv.py index 2d1dacf2cd39..e856abc14be6 100644 --- a/colossalai/shardformer/layer/linear_conv.py +++ b/colossalai/shardformer/layer/linear_conv.py @@ -14,13 +14,18 @@ from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise -from colossalai.utils.cuda import get_current_device +from colossalai.tensor.d_tensor.api import ( + customized_distributed_tensor_to_param, + distribute_tensor_with_customization, + shard_rowwise, + sharded_tensor_to_param, +) from ._operation import ( gather_forward_split_backward, - linear_with_async_comm, - reduce_input, + matmul_with_async_comm, + reduce_backward, + reduce_forward, split_forward_gather_backward, ) from .parallel_module import ParallelModule @@ -29,11 +34,69 @@ __all__ = ['LinearConv1D_Col', 'LinearConv1D_Row'] +def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup): + """ + The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2]. + """ + # get the number of slice for the fused qkv + rank = dist.get_rank(group=process_group) + world_size = dist.get_world_size(group=process_group) + order = torch.arange(world_size * n_fused) + + # split the fused qkv + # from + # [Q, K, V] + # to + # [Q1, Q2, K1, K2, V1, V2] + weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1) + + # rearrange the slice into the final order + # from + # [Q1, Q2, K1, K2, V1, V2] + # to + # [Q1, K1, V1], [Q2, K2, V2] + weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]] + weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1) + return weight_of_current_rank + + +def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup): + """ + The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2]. + """ + world_size = dist.get_world_size(group=process_group) + + # gather the tensors + # from + # [Q1, K1, V1], [Q2, K2, V2] + # to + # [Q1, K1, V1, Q2, K2, V2] + origin_device = qkv.device + qkv = qkv.cuda() + gather_list = [torch.zeros_like(qkv) for _ in range(world_size)] + dist.all_gather(gather_list, qkv, group=process_group) + gather_weight = torch.cat(gather_list, dim=-1) + gather_weight = gather_weight.to(origin_device) + qkv = qkv.to(origin_device) + + # rearrange the tensor slices + # from + # [Q1, K1, V1, Q2, K2, V2] + # to + # [Q1, Q2, K1, K2, V1, V2] + weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1) + reordered_chunk_list = [] + for i in range(n_fused): + reordered_chunk_list.extend(weight_chunks[i::n_fused]) + reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1) + return reordered_gather_weight + + class LinearConv1D_Col(ParallelModule): r"""Linear layer with column parallelism. The linear layer is defined as :math:`Y = XA + b`. A is parallelized along - its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface. + its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. Args: in_features (int): size of each input sample. @@ -41,6 +104,7 @@ class LinearConv1D_Col(ParallelModule): bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``. dtype (`torch.dtype`): The dtype of parameters, defaults to None. device (`torch.device`): The device of parameters, defaults to None. + n_fused (int): The number items fused, defaults to 3 (QKV). process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None. gather_output (bool, optional): If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output @@ -63,8 +127,10 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + async_communication: bool = False, gather_output: bool = False, skip_bias_add: bool = False, + n_fused: int = 3, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): super().__init__() @@ -75,23 +141,34 @@ def __init__(self, self.gather_output = gather_output self.skip_bias_add = skip_bias_add self.device = device + self.n_fused = n_fused self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) + self.async_communication = async_communication if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - self.out_features_per_partition = divide(out_features, self.num_partitions) - # Parameters. # Initialize weight. - if device is None: - device = get_current_device() factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) + + def shard_fn(tensor): + return split_fused_qkv(tensor, self.n_fused, self.process_group) + + def gather_fn(tensor): + return gather_fused_qkv(tensor, 3, self.process_group) + + with torch.no_grad(): + sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) + self.weight = customized_distributed_tensor_to_param(sharded_weight) if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + bias = torch.empty(self.out_features, **factory_kwargs) + + with torch.no_grad(): + sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn) + self.bias = customized_distributed_tensor_to_param(sharded_bias) else: self.bias = None @@ -103,7 +180,7 @@ def __init__(self, self.reset_parameters(weight_initializer, bias_initializer) @staticmethod - def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs) -> ParallelModule: r""" Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer. @@ -135,29 +212,12 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis # TODO: copy the sharded weights with torch.no_grad(): - # the weigh to the linear layer is a transpose - # thus shard on row is equal to shard on column - - # first rearange the order of weight and bias - world_size = dist.get_world_size(group=process_group) - order = torch.arange(world_size * n_fused) - new_order = [] - for i in range(world_size): - new_order.append(order[i::world_size]) - new_order = torch.cat(new_order) - - weight_chunks = torch.chunk(module.weight.data, world_size * n_fused, dim=1) - rearanged_weight_chunks = [weight_chunks[i] for i in new_order] - rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1) - sharded_weight = shard_colwise(rearanged_weight, process_group) - linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) + sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group) + linear_1d.weight.data.copy_(sharded_weight.data) if bias: - bias_chunks = torch.chunk(module.bias.data, world_size * n_fused, dim=0) - rearanged_bias_chunks = [bias_chunks[i] for i in new_order] - rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) - sharded_bias = shard_colwise(rearanged_bias, process_group) - linear_1d.bias.copy_(sharded_bias.contiguous()) + sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group) + linear_1d.bias.data.copy_(sharded_bias.data) return linear_1d @@ -169,15 +229,18 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: bias_initializer(self.bias, fan_in=fan_in) def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: - assert input_.shape[-1] == self.weight.shape[-1], \ + assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) # Set up backprop all-reduce. - # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) - input_parallel = input_ + input_parallel = reduce_backward(input_, self.process_group) + # input_parallel = input_ + # Matrix multiply. bias = self.bias if not self.skip_bias_add else None - output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True) + + output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group, + self.async_communication) if self.gather_output: # All-gather across the partitions. @@ -192,7 +255,8 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: class LinearConv1D_Row(ParallelModule): - r""" Linear layer with row parallelism + r""" Linear layer with row parallelism. + This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. Args: in_features (int): size of each input sample. @@ -243,11 +307,10 @@ def __init__(self, # Parameters. # Initialize weight. - if device is None: - device = get_current_device() - factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) + sharded_weight = shard_rowwise(weight, self.process_group) + self.weight = sharded_tensor_to_param(sharded_weight) if self.stream_chunk_num > 1: # TODO() work for inference only @@ -295,7 +358,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis # the weigh to the linear layer is a transpose # thus shard on col is equal to shard on row sharded_weight = shard_rowwise(module.weight.data, process_group) - linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) + linear_1d.weight.data.copy_(sharded_weight.data) if bias: linear_1d.bias.copy_(module.bias.data) @@ -325,12 +388,12 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. if self.parallel_input: - assert input_.shape[-1] == self.weight.shape[-1], \ + assert input_.shape[-1] == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) input_ = input_ else: - assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \ + assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \ 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions) input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group) @@ -342,7 +405,7 @@ def forward(self, input_: Tensor) -> Tensor: output_parallel_list = [None for i in range(self.stream_chunk_num)] handle_list = [] for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + output_parallel_list[i] = torch.matmul(input_, self.weight_list[i]) handle = torch.distributed.all_reduce(output_parallel_list[i], group=self.process_group, async_op=True) @@ -352,9 +415,8 @@ def forward(self, input_: Tensor) -> Tensor: handle.wait() output = torch.cat(output_parallel_list, dim=-1) else: - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, self.process_group) + output_parallel = torch.matmul(input_, self.weight) + output = reduce_forward(output_parallel, self.process_group) if not self.skip_bias_add: if self.bias is not None: diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index 5edcb9dde748..bda147b121ab 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -12,11 +12,14 @@ from colossalai.tensor.d_tensor import ( distribute_tensor, + distribute_tensor_with_customization, get_device_mesh, get_sharding_spec, + is_customized_distributed_tensor, is_distributed_tensor, sharded_tensor_to_param, to_global, + to_global_for_customized_distributed_tensor, ) __all__ = ['ParallelModule'] @@ -54,9 +57,10 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): for name, param in self._parameters.items(): if param is not None: param_ = param if keep_vars else param.detach() - if is_distributed_tensor(param_): destination[prefix + name] = to_global(param_) + elif is_customized_distributed_tensor(param_): + destination[prefix + name] = to_global_for_customized_distributed_tensor(param_) else: destination[prefix + name] = param_ @@ -124,6 +128,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss sharding_spec = get_sharding_spec(param) sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec) input_param = sharded_tensor_to_param(sharded_tensor) + elif is_customized_distributed_tensor(param): + input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn) # This is used to avoid copying uninitialized parameters into # non-lazy modules, since they dont have the hook to do the checks diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index 52eae0e14877..3ae38a12555b 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -1,10 +1,13 @@ from .api import ( compute_global_numel, + customized_distributed_tensor_to_param, distribute_tensor, + distribute_tensor_with_customization, get_device_mesh, get_global_shape, get_layout, get_sharding_spec, + is_customized_distributed_tensor, is_distributed_tensor, is_sharded, redistribute, @@ -12,6 +15,7 @@ shard_rowwise, sharded_tensor_to_param, to_global, + to_global_for_customized_distributed_tensor, ) from .layout import Layout from .sharding_spec import ShardingSpec @@ -19,6 +23,6 @@ __all__ = [ 'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise', 'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh', - 'redistribute', 'get_layout' - 'Layout', 'ShardingSpec' + 'redistribute', 'get_layout', 'is_customized_distributed_tensor', 'distribute_tensor_with_customization', + 'to_global_for_customized_distributed_tensor', 'customized_distributed_tensor_to_param', 'Layout', 'ShardingSpec' ] diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index a38e5e6b7184..95a44e09e16a 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -305,3 +305,130 @@ def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec: """ assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' return dtensor.dist_layout.sharding_spec + + +# ====================================================== +# Some sharding does not obey the SPMD style +# e.g. Fused QKV layer in GPT2 +# we support customize sharding with the following APIs +# ====================================================== +def is_customized_distributed_tensor(tensor: torch.Tensor): + """ + Check whether the given tensor is a customized distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a customized distributed tensor. + """ + return hasattr(tensor, 'shard_fn') and hasattr(tensor, 'gather_fn') + + +def _hijack_detach_and_clone_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + dtensor._old_detach = dtensor.detach + dtensor._old_clone = dtensor.clone + + def new_detach(self): + t_ = self._old_detach() + t_.shard_fn = self.shard_fn + t_.gather_fn = self.gather_fn + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._old_clone(*args, **kwargs) + t_.shard_fn = self.shard_fn + t_.gather_fn = self.gather_fn + return t_ + + # bind the new methods to the tensor + dtensor.detach = new_detach.__get__(dtensor) + dtensor.clone = new_clone.__get__(dtensor) + return dtensor + + +def distribute_tensor_with_customization(tensor: torch.Tensor, shard_fn, gather_fn: callable): + """ + Distribute the given tensor with the given shard_fn and gather_fn. + + Example: + + ```python + # define shard and gather functions + def shard_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + return tensor.chunk(world_size, dim=0)[rank] + + def gather_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + shard_list = [torch.zeros_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(shard_list, tensor) + return torch.cat(shard_list, dim=0) + + # create a distributed tensor + tensor = torch.rand(4, 4) + dtensor = distribute_tensor_with_customization(tensor, shard_fn, gather_fn) + ``` + + Args: + tensor (torch.Tensor): The tensor to be distributed. + shard_fn (callable): The function to shard the tensor. + gather_fn (callable): The function to gather the tensor. + + Returns: + torch.Tensor: The distributed tensor. + """ + assert callable(shard_fn), 'The shard_fn must be callable.' + assert callable(gather_fn), 'The gather_fn must be callable.' + assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.' + + sharded_tensor = shard_fn(tensor) + + # set the shard_fn and gather_fn as attributes of the distributed tensor + sharded_tensor.shard_fn = shard_fn + sharded_tensor.gather_fn = gather_fn + + # set the shard_fn and gather_fn as attributes of the distributed tensor + _hijack_detach_and_clone_for_customized_distributed_tensor(sharded_tensor) + + return sharded_tensor + + +def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: + """ + Gather the given tensor to the global tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + torch.Tensor: The global tensor. + """ + assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + return dtensor.gather_fn(dtensor) + + +def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): + """ + Convert the given customized distributed tensor to a parameter. + """ + assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.' + + param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) + + # make it distributed as well + param.shard_fn = dtensor.shard_fn + param.gather_fn = dtensor.gather_fn + _hijack_detach_and_clone_for_customized_distributed_tensor(param) + return param diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index a2b8bf22c0b2..da3bdc1d78d3 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -27,8 +27,13 @@ def check_linear_1d_col(): # check computation correctness x = torch.rand(4, 32).cuda() - out = linear(x) - gather_out = linear_col(x) + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + out = linear(x_for_unshard) + gather_out = linear_col(x_for_shard) assert_close(out, gather_out) # check backward correctness @@ -39,6 +44,11 @@ def check_linear_1d_col(): target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] assert_close(target_grad, linear_col.weight.grad) + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + def check_linear_1d_row(): linear = nn.Linear(32, 128).cuda() @@ -49,8 +59,14 @@ def check_linear_1d_row(): # check computation correctness x = torch.rand(4, 32).cuda() - out = linear(x) - gather_out = linear_row(x) + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + # run forward + out = linear(x_for_unshard) + gather_out = linear_row(x_for_shard) assert_close(out, gather_out) # check backward correctness @@ -61,11 +77,49 @@ def check_linear_1d_row(): target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] assert_close(target_grad, linear_row.weight.grad) + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + + +def check_linear_col_plus_row(): + linear_1 = nn.Linear(32, 128).cuda() + linear_2 = nn.Linear(128, 32).cuda() + linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False) + linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True) + + # check computation correctness + x = torch.rand(4, 32).cuda() + x_for_unshard = x.expand_as(x.clone()) + x_for_unshard.requires_grad_(True) + x_for_shard = x.expand_as(x.clone()) + x_for_shard.requires_grad_(True) + + # run forward + unshard_out = linear_2(linear_1(x_for_unshard)) + shard_out = linear_row(linear_col(x_for_shard)) + assert_close(unshard_out, shard_out) + + # check backward correctness + unshard_out.sum().backward() + shard_out.sum().backward() + + rank = dist.get_rank() + target_1_grad = torch.chunk(linear_1.weight.grad, 2, dim=0)[rank] + assert_close(target_1_grad, linear_col.weight.grad) + + # check the input gradients + assert x_for_shard.grad is not None + assert x_for_unshard.grad is not None + assert_close(x_for_unshard.grad, x_for_shard.grad) + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') check_linear_1d_col() - # check_linear_1d_row() + check_linear_1d_row() + check_linear_col_plus_row() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_layer/test_linearconv_1d.py b/tests/test_shardformer/test_layer/test_linearconv_1d.py index e0c97178d901..efdb88351519 100644 --- a/tests/test_shardformer/test_layer/test_linearconv_1d.py +++ b/tests/test_shardformer/test_layer/test_linearconv_1d.py @@ -5,6 +5,7 @@ import colossalai from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row +from colossalai.shardformer.layer.linear_conv import split_fused_qkv from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -53,9 +54,15 @@ def check_linear_conv_1d_col(): linear = Conv1D(192, 48).cuda() linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3) - assert linear_conv_col.weight.shape == torch.Size([96, 48]) + assert linear.weight.shape == torch.Size([48, 192]) + assert linear.bias.shape == torch.Size([192]) + assert linear_conv_col.weight.shape == torch.Size([48, 96]) assert linear_conv_col.bias.shape == torch.Size([96]) + # ensure weights are reversibly loadable + linear_conv_col.load_state_dict(linear.state_dict()) + linear.load_state_dict(linear_conv_col.state_dict()) + # check computation correctness x = torch.rand(4, 48).cuda() out = linear(x) @@ -66,16 +73,16 @@ def check_linear_conv_1d_col(): out.sum().backward() gather_out.sum().backward() - rank = dist.get_rank() - target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] - assert_close(target_grad.transpose(0, 1).contiguous(), linear_conv_col.weight.grad) + target_grad = split_fused_qkv(linear.weight.grad, 3, None) + assert_close(target_grad, linear_conv_col.weight.grad) def check_linear_1d_row(): linear = Conv1D(192, 48).cuda() linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) - assert linear_row.weight.shape == torch.Size([192, 24]) + assert linear.weight.shape == torch.Size([48, 192]) + assert linear_row.weight.shape == torch.Size([24, 192]) assert linear_row.bias.shape == torch.Size([192]) # check computation correctness @@ -89,13 +96,14 @@ def check_linear_1d_row(): gather_out.sum().backward() rank = dist.get_rank() - target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank] + target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank] assert_close(target_grad, linear_row.weight.grad) def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') check_linear_conv_1d_col() + check_linear_1d_row() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 9aa02ec34d17..676267c2ca2a 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -20,20 +20,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check grad equality if org_model.__class__.__name__ == 'GPT2Model': - org_grad = org_model.h[0].attn.c_attn.weight.grad - shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous() + org_grad = org_model.h[0].mlp.c_fc.weight.grad + shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad else: org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad - shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous() + shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=1) assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" - assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose( + org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" def check_gpt2(rank, world_size, port): From 12801e8cbd393333abb30c9784fb5e6c913d69df Mon Sep 17 00:00:00 2001 From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 23 Jun 2023 18:00:22 +0800 Subject: [PATCH 35/49] [shardformer] Add layernorm (#4072) * add layernorm to bert * add layernorm test * add layernorm test with load state dict * add use_mixedfusedLN in shard config * refactor policy to support fused_layernorm --- colossalai/shardformer/layer/__init__.py | 3 +- colossalai/shardformer/layer/layernorm.py | 89 +++++++++++++ colossalai/shardformer/policies/bert.py | 122 ++++++++++++++++-- colossalai/shardformer/shard/shard_config.py | 4 +- .../test_layer/test_layernorm.py | 45 +++++++ .../test_layer/test_linearconv_1d.py | 4 +- tests/test_shardformer/test_model/_utils.py | 2 +- 7 files changed, 252 insertions(+), 17 deletions(-) create mode 100644 colossalai/shardformer/layer/layernorm.py create mode 100644 tests/test_shardformer/test_layer/test_layernorm.py diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 808ebbc12aeb..3ce0ef68aa4f 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,10 +1,11 @@ from .dropout import Dropout1D from .embedding import Embedding1D, VocabParallelEmbedding1D +from .layernorm import LayerNorm1D from .linear import Linear1D_Col, Linear1D_Row from .linear_conv import LinearConv1D_Col, LinearConv1D_Row from .loss import cross_entropy_1d __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row", - "Dropout1D", "cross_entropy_1d" + "Dropout1D", "cross_entropy_1d", 'LayerNorm1D' ] diff --git a/colossalai/shardformer/layer/layernorm.py b/colossalai/shardformer/layer/layernorm.py new file mode 100644 index 000000000000..a8e1d7a2c082 --- /dev/null +++ b/colossalai/shardformer/layer/layernorm.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import List, Union + +import torch +import torch.nn as nn +from torch.distributed import ProcessGroup + +from colossalai.kernel import LayerNorm +from colossalai.nn import init as init + +from .parallel_module import ParallelModule + +__all__ = ['LayerNorm1D'] + +Fast_LN = None +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + Fast_LN = FastLayerNorm +except ImportError: + pass + + +class LayerNorm1D(ParallelModule): + r""" + Layer Normalization for colossalai + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + """ + + _fast_ln_supported_sizes = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, + 24576, 25600, 30720, 32768, 40960, 49152, 65536 + ] + + def __init__(self, + normalized_shape: int, + eps: int = 1e-05, + bias: bool = True, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: + norm = Fast_LN(normalized_shape, eps=eps).to(dtype) + else: + norm = None + try: + from apex.normalization import FusedLayerNorm + norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) + except ImportError: + norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype) + self.norm = norm + + @staticmethod + def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: + r""" + Convert a native pytorch layer norm module to colossalai layer norm module + """ + normalized_shape = module.normalized_shape + eps = module.eps + bias = module.bias is not None + dtype = module.weight.dtype + device = module.weight.device + + # ensure only one process group is passed + if isinstance(process_group, (list, tuple)): + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' + process_group = process_group[0] + + # create layer norm + layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm + + with torch.no_grad(): + # copy weight and bias + layer_norm.weight.copy_(module.weight) + if bias: + layer_norm.bias.copy_(module.bias) + return layer_norm diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 8649c0dbeaa6..1baf67ef9c02 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,8 +1,14 @@ import torch.nn as nn -from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead +from transformers.models.bert.modeling_bert import ( + BertEmbeddings, + BertForMultipleChoice, + BertForSequenceClassification, + BertForTokenClassification, + BertLayer, + BertLMPredictionHead, +) import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.layer.dropout import Dropout1D from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -24,7 +30,7 @@ def preprocess(self): return self.model def module_policy(self): - return { + base_policy = { BertLayer: ModulePolicyDescription( attribute_replacement={ @@ -53,10 +59,18 @@ def module_policy(self): suffix="attention.self.value", target_module=col_nn.Linear1D_Col, ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.Dropout1D, + ), SubModuleReplacementDescription( suffix="attention.output.dense", target_module=col_nn.Linear1D_Row, ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.Dropout1D, + ), SubModuleReplacementDescription( suffix="intermediate.dense", target_module=col_nn.Linear1D_Col, @@ -66,12 +80,8 @@ def module_policy(self): target_module=col_nn.Linear1D_Row, ), SubModuleReplacementDescription( - suffix="attention.self.dropout", - target_module=Dropout1D, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=Dropout1D, + suffix="output.dropout", + target_module=col_nn.Dropout1D, ) ]), BertEmbeddings: @@ -81,10 +91,32 @@ def module_policy(self): SubModuleReplacementDescription( suffix="word_embeddings", target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.Dropout1D, ) ]) } + if self.shard_config.fused_layernorm: + base_policy[BertLayer].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="attention.output.LayerNorm", + target_module=col_nn.LayerNorm1D, + )) + base_policy[BertLayer].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="output.LayerNorm", + target_module=col_nn.LayerNorm1D, + )) + base_policy[BertEmbeddings].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="LayerNorm", + target_module=col_nn.LayerNorm1D, + ),) + return base_policy + def new_model_class(self): # do nothing return self.model @@ -115,9 +147,15 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription(suffix="decoder", target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}) + kwargs={"gather_output": True}), ]) } + if self.shard_config.fused_layernorm: + addon_module[BertLMPredictionHead].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="transform.LayerNorm", + target_module=col_nn.LayerNorm1D, + )) module_policy.update(addon_module) return module_policy @@ -146,9 +184,15 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription(suffix="decoder", target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}) + kwargs={"gather_output": True}), ]) } + if self.shard_config.fused_layernorm: + addon_module[BertLMPredictionHead].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="transform.LayerNorm", + target_module=col_nn.LayerNorm1D, + )) module_policy.update(addon_module) return module_policy @@ -177,9 +221,15 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription(suffix="decoder", target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}) + kwargs={"gather_output": True}), ]) } + if self.shard_config.fused_layernorm: + addon_module[BertLMPredictionHead].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="transform.LayerNorm", + target_module=col_nn.LayerNorm1D, + )) module_policy.update(addon_module) return module_policy @@ -199,6 +249,22 @@ class BertForSequenceClassificationPolicy(BertPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + module_policy = super().module_policy() + addon_module = { + BertForSequenceClassification: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.Dropout1D, + ) + ]) + } + module_policy.update(addon_module) + return module_policy + # BertForTokenClassification class BertForTokenClassificationPolicy(BertPolicy): @@ -206,6 +272,22 @@ class BertForTokenClassificationPolicy(BertPolicy): def __init__(self) -> None: super().__init__() + def module_policy(self): + module_policy = super().module_policy() + addon_module = { + BertForTokenClassification: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.Dropout1D, + ) + ]) + } + module_policy.update(addon_module) + return module_policy + # BertForNextSentencePrediction class BertForNextSentencePredictionPolicy(BertPolicy): @@ -219,3 +301,19 @@ class BertForMultipleChoicePolicy(BertPolicy): def __init__(self) -> None: super().__init__() + + def module_policy(self): + module_policy = super().module_policy() + addon_module = { + BertForMultipleChoice: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.Dropout1D, + ) + ]) + } + module_policy.update(addon_module) + return module_policy diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 7379a8208745..8d3fc225e894 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -11,8 +11,9 @@ class ShardConfig: The config for sharding the huggingface model Args: - data_parallel_size (int): The size of data parallel tensor_parallel_size (int): The size of tensor parallel + use_mixedfusedLN (bool): Whether to use the `MixedFusedLayerNorm` + data_parallel_size (int): The size of data parallel pipeline_parallel_size (int): The size of pipeline parallel tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d'] inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model @@ -20,6 +21,7 @@ class ShardConfig: gather_output (bool): Whether to gather the output of the model of the last layer """ tensor_parallel_size: int + fused_layernorm: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py new file mode 100644 index 000000000000..334ae05bed95 --- /dev/null +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -0,0 +1,45 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import LayerNorm1D +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +def check_layernorm_1d(): + norm = nn.LayerNorm(128, 0.00001).cuda() + norm1d = LayerNorm1D.from_native_module(norm, process_group=None) + + assert norm1d.weight.shape == torch.Size([128]) + + # ensure state dict is reversibly loadable + norm.load_state_dict(norm1d.state_dict()) + norm1d.load_state_dict(norm.state_dict()) + + # check computation correctness + x = torch.rand(4, 128).cuda() + out = norm(x) + gather_out = norm1d(x) + assert_close(out, gather_out) + + # check backward correctness + out.sum().backward() + gather_out.sum().backward() + + assert_close(norm.weight.grad, norm1d.weight.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_layernorm_1d() + + +@rerun_if_address_is_in_use() +def test_layernorm_1d(): + spawn(run_dist, nprocs=2) + + +if __name__ == '__main__': + test_layernorm_1d() diff --git a/tests/test_shardformer/test_layer/test_linearconv_1d.py b/tests/test_shardformer/test_layer/test_linearconv_1d.py index efdb88351519..774e6340eee6 100644 --- a/tests/test_shardformer/test_layer/test_linearconv_1d.py +++ b/tests/test_shardformer/test_layer/test_linearconv_1d.py @@ -77,7 +77,7 @@ def check_linear_conv_1d_col(): assert_close(target_grad, linear_conv_col.weight.grad) -def check_linear_1d_row(): +def check_linear_conv_1d_row(): linear = Conv1D(192, 48).cuda() linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) @@ -103,7 +103,7 @@ def check_linear_1d_row(): def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') check_linear_conv_1d_col() - check_linear_1d_row() + check_linear_conv_1d_row() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 52ca7fce895b..a282e0bb919e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -8,7 +8,7 @@ def build_model(world_size, model_fn): org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(tensor_parallel_size=world_size) + shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) shard_former.init_distributed() From d88844c093f1b4fe203edd10effa664deb619970 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 26 Jun 2023 15:50:07 +0800 Subject: [PATCH 36/49] [test] fixed tests failed due to dtensor change (#4082) * [test] fixed tests failed due to dtensor change * polish code --- .../tensor_shard/node_handler/node_handler.py | 2 +- .../strategy/matmul_strategy_generator.py | 6 +- .../auto_parallel/tensor_shard/utils/misc.py | 4 +- colossalai/checkpoint_io/utils.py | 4 +- colossalai/lazy/lazy_init.py | 11 ++- colossalai/tensor/comm_spec.py | 97 +++++++++---------- colossalai/tensor/d_tensor/comm_spec.py | 88 ++++++++--------- colossalai/tensor/d_tensor/layout.py | 13 +-- .../tensor/d_tensor/layout_converter.py | 71 +++++++------- colossalai/tensor/shape_consistency.py | 6 +- colossalai/tensor/sharding_spec.py | 6 +- test.py | 1 - .../test_autochunk_unet.py | 11 +-- .../test_gemini_checkpoint_io.py | 4 +- tests/test_device/test_device_mesh.py | 10 +- tests/test_device/test_init_logical_pg.py | 16 ++- .../test_hf_model/hf_tracer_utils.py | 14 ++- .../test_hf_model/test_hf_albert.py | 2 +- .../test_tracer/test_hf_model/test_hf_bert.py | 4 +- .../test_hf_model/test_hf_diffuser.py | 2 +- .../test_tracer/test_hf_model/test_hf_gpt.py | 4 +- .../test_tracer/test_hf_model/test_hf_opt.py | 2 +- .../test_tracer/test_hf_model/test_hf_t5.py | 9 +- .../test_timm_model/test_timm_model.py | 2 +- .../test_torchaudio_model.py | 2 +- .../test_torchrec_model/test_deepfm_model.py | 2 +- .../test_torchrec_model/test_dlrm_model.py | 2 +- .../test_torchvision_model.py | 2 +- tests/test_lazy/lazy_init_utils.py | 4 +- tests/test_lazy/test_distribute.py | 30 +++--- tests/test_lazy/test_models.py | 2 +- .../test_dtensor/test_comm_spec.py | 33 ++----- .../test_tensor/test_dtensor/test_dtensor.py | 2 +- .../test_dtensor/test_layout_converter.py | 43 +++----- tests/test_tensor/test_shape_consistency.py | 7 +- tests/test_tensor/test_sharded_linear.py | 2 +- tests/test_tensor/test_sharding_spec.py | 2 +- 37 files changed, 233 insertions(+), 289 deletions(-) delete mode 100644 test.py diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 4262d76173e4..b4b7b0e794d1 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -188,7 +188,7 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV remove_strategy_list = [] for strategy in self.strategies_vector: shard_axis_list = [] - last_axis = len(self.device_mesh.mesh_shape) - 1 + last_axis = len(self.device_mesh.shape) - 1 for op_data, sharding_spec in strategy.sharding_specs.items(): if op_data.data is not None and isinstance(op_data.data, torch.Tensor): for dim, shard_axes in sharding_spec.dim_partition_dict.items(): diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index 1ce5a08f2d6b..aa1581b99e0f 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -984,7 +984,7 @@ def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1): def collate_strategies(self) -> List[ShardingStrategy]: strategy_list = [] device_mesh_is_1d = True - if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape: + if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape: device_mesh_is_1d = False if device_mesh_is_1d: @@ -992,10 +992,10 @@ def collate_strategies(self) -> List[ShardingStrategy]: # Sb = Sb x Sb # can be None as it is only for 1D device mesh # only for 1D device mesh - if len(self.device_mesh.mesh_shape) == 1: + if len(self.device_mesh.shape) == 1: mesh_dim = 0 else: - mesh_dim = self.device_mesh.mesh_shape.index(1) + mesh_dim = self.device_mesh.shape.index(1) strategy_list.append(self.split_one_batch_dim(mesh_dim)) else: # for 2D device mesh diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index 9e402dab7578..475e95fc4326 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -46,8 +46,8 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens # make sure all dims are covered in sharding spec sharding_len = len(sharding_spec.sharding_sequence) tensor_num_dim = tensor.dim() - num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0] - num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1] + num_devices_in_col = sharding_spec.device_mesh.shape[0] + num_devices_in_row = sharding_spec.device_mesh.shape[1] assert sharding_len == tensor_num_dim, \ f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).' diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 68981dff0d0a..485577b9650c 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) for key, weight in state_dict.items(): ret_block = None ret_block_size = 0 - if is_distributed_tensor(weight): + if not is_distributed_tensor(weight): weight_size = calculate_tensor_size(weight) # If this weight is going to tip up over the maximal size, we split. @@ -146,7 +146,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> continue # If the states are stored as DTensors, mark isDTensor as true. - if type(state_tensor) == DTensor: + if is_distributed_tensor(state_tensor): isDTensor = True state_size += calculate_tensor_size(state_tensor) diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 1e45eced5f34..8b911407307c 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -1,5 +1,5 @@ from types import MethodType -from typing import Callable, Optional, Union +from typing import Callable, Dict, Optional, Union import torch import torch.distributed as dist @@ -173,7 +173,7 @@ def materialize(self) -> torch.Tensor: self.clean() return _convert_cls(self, target) - def distribute(self, layout: Layout) -> torch.Tensor: + def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout. Args: @@ -537,7 +537,10 @@ def apply_fn(name: str, p: LazyTensor): return _apply_to_lazy_module(module, apply_fn, verbose) @staticmethod - def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module: + def distribute(module: nn.Module, + device_mesh: DeviceMesh, + sharding_spec_dict: Dict[str, ShardingSpec], + verbose: bool = False) -> nn.Module: """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place. Args: @@ -547,7 +550,7 @@ def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> n """ def apply_fn(name: str, p: LazyTensor): - p.distribute(layout_dict[name]) + p.distribute(device_mesh, sharding_spec_dict[name]) return _apply_to_lazy_module(module, apply_fn, verbose) diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index af38d2a502c2..204f81343199 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -16,69 +16,66 @@ def _all_gather(tensor, comm_spec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) - for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) - ] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) + for _ in range(comm_spec.device_mesh.shape[comm_spec.logical_process_axis]) + ] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor, comm_spec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor, comm_spec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor, comm_spec, async_op=False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_groups = comm_spec.device_mesh.get_process_group_for_all_axes() + process_group = process_groups[comm_spec.logical_process_axis] + + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor def _mix_gather(tensor, comm_spec): @@ -128,7 +125,7 @@ def _mix_gather(tensor, comm_spec): process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] ''' - total_slices = comm_spec.device_mesh.mesh_shape[0] + total_slices = comm_spec.device_mesh.shape[0] tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)] leading_group_dim = comm_spec.logical_process_axes[0] assert len(comm_spec.device_mesh.process_groups_dict) == 1 @@ -149,7 +146,7 @@ def _mix_gather(tensor, comm_spec): if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous() else: - mesh_shape = comm_spec.device_meshes.mesh_shape + mesh_shape = comm_spec.device_meshes.shape cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] tmp_tensor_shape = list(tensor.shape) tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0] @@ -181,9 +178,9 @@ def _mix_split(tensor, comm_spec): # [4, 5, 6, 7]] # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} ''' - mesh_shape = comm_spec.device_meshes.mesh_shape + mesh_shape = comm_spec.device_meshes.shape dim = comm_spec.gather_dim - total_slices = comm_spec.device_mesh.mesh_shape[0] + total_slices = comm_spec.device_mesh.shape[0] # Get global rank rank = dist.get_rank() @@ -414,7 +411,7 @@ def __init__(self, self.forward_only = forward_only if isinstance(self.logical_process_axis, list): if not mix_gather: - self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.device_mesh = self.sharding_spec.device_mesh.flatten() self.logical_process_axis = 0 else: self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py index 159125fa16db..79b2e3ef936a 100644 --- a/colossalai/tensor/d_tensor/comm_spec.py +++ b/colossalai/tensor/d_tensor/comm_spec.py @@ -24,12 +24,12 @@ class CommSpec: ''' Communication spec is used to record the communication action. It converts the communication spec to real action which will be used in runtime. It contains comm_pattern to determine the - communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim + communication method, process_group_dict to determine the process groups, gather_dim and shard_dim to determine the buffer shape, and logical_process_axis Argument: - comm_pattern(CollectiveCommPattern): describe the communication method used in this spec. - process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec. + comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. + process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec. gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. @@ -37,7 +37,7 @@ class CommSpec: def __init__(self, comm_pattern: CollectiveCommPattern, - process_groups_dict: Dict, + process_group_dict: Dict, gather_dim: int = None, shard_dim: int = None, logical_process_axis: int = None): @@ -45,7 +45,7 @@ def __init__(self, self.gather_dim = gather_dim self.shard_dim = shard_dim self.logical_process_axis = logical_process_axis - self.process_groups_dict = process_groups_dict + self.process_group_dict = process_group_dict def __repr__(self): res_list = ["CommSpec:("] @@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - # without this contiguous operation, the all gather may get some unexpected results. - tensor = tensor.contiguous() - dist.all_gather(tensor_list, tensor, group=process_group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + # without this contiguous operation, the all gather may get some unexpected results. + tensor = tensor.contiguous() + dist.all_gather(tensor_list, tensor, group=process_group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output def _split(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group) + start = length * dist.get_rank(process_group) + output = torch.narrow(tensor, dim, start, length).contiguous() + return output def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + world_size = dist.get_world_size(process_group) + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size + new_shape = torch.Size(new_shape) + output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // world_size + input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' - process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - if not tensor.is_contiguous(): - tensor = tensor.contiguous() - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) - return tensor + process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis] + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op) + return tensor class _ReduceGrad(torch.autograd.Function): @@ -269,7 +257,7 @@ def symbolic(graph, input_): def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - process_groups_dict=comm_spec.process_groups_dict, + process_group_dict=comm_spec.process_group_dict, gather_dim=comm_spec.shard_dim, shard_dim=comm_spec.gather_dim, logical_process_axis=comm_spec.logical_process_axis) diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index 4185b85860e3..a35b2f43e44b 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -14,24 +14,21 @@ class Layout: Attributes: device_mesh: the device mesh to store the tensor distributed. - device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'. sharding_spec: the sharding specification to describe how the tensor is sharded. - entire_shape: the entire shape of the global tensor. + global_shape: the entire shape of the global tensor. """ - def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec, - entire_shape: torch.Size): + def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size): self.device_mesh = device_mesh - self.device_type = device_type self.sharding_spec = sharding_spec - self.entire_shape = entire_shape + self.global_shape = global_shape self._sanity_check() def __hash__(self) -> int: return hash(f'{self.sharding_spec}') def get_sharded_shape_per_device(self): - sharded_shape = list(self.entire_shape) + sharded_shape = list(self.global_shape) for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) @@ -56,7 +53,7 @@ def _sanity_check(self): # make sure that the sharding for a dimension is divisible by the number of devices for dim, shard_list in sharding_spec.dim_partition_dict.items(): - tensor_dim_size = self.entire_shape[dim] + tensor_dim_size = self.global_shape[dim] num_devices = 1 for element in shard_list: diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index 14f9c4561622..528ed7901c4f 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -3,10 +3,8 @@ from dataclasses import dataclass from typing import Dict, List, Tuple -import numpy as np import torch -from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.layout import Layout @@ -37,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions): class LayoutConverter(metaclass=SingletonMeta): + """ + LayoutConverter is a singleton class which converts the layout of a distributed tensor. + """ def __init__(self): self._options = None @@ -79,15 +80,14 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) for layout, comm_spec in rst_dict.items(): @@ -100,7 +100,12 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + for target_pair in source_spec.dim_partition_dict.items(): shard_list = all_gather_simulator(target_pair) index = target_pair[0] @@ -118,7 +123,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co logical_process_axis = target_pair[1][-1] comm_spec = CommSpec( comm_pattern, - process_groups_dict=process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, # shard_dim will be used during backward shard_dim=gather_dim, @@ -129,8 +134,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: @@ -155,15 +159,14 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0], 1: [1]} # [S0,S1,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.all_to_all_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -176,7 +179,12 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] + source_spec = source_layout.sharding_spec tensor_dims = source_spec.dims for f_index in range(tensor_dims - 1): @@ -217,7 +225,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com shard_dim = f_index logical_process_axis = b_target_pair[1][-1] comm_spec = CommSpec(comm_pattern, - process_groups_dict, + process_group_dict=process_group_dict, gather_dim=gather_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -240,8 +248,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -266,16 +273,15 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_dict = {0: [0]} # [S0,R,R] sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec, - entire_shape=entire_shape) + global_shape=global_shape) rst_dict = layout_converter.shard_transform_layout(layout) for layout, comm_spec in rst_dict.items(): @@ -289,7 +295,11 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD source_spec = source_layout.sharding_spec - process_groups_dict = source_layout.device_mesh.process_groups_dict + + # the key of the dict is the axis + # the value is the process group + current_rank = source_layout.device_mesh._global_rank_of_current_process + process_group_dict = source_layout.device_mesh._process_group_dict[current_rank] # legal sharding dims means the mesh_id is still available to use. legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))] @@ -317,7 +327,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec shard_dim = index logical_process_axis = shard_list[-1] comm_spec = CommSpec(comm_pattern, - process_groups_dict, + process_group_dict=process_group_dict, gather_dim=shard_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis) @@ -328,8 +338,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec dim_partition_dict=new_dim_partition_dict) new_layout = Layout(device_mesh=source_layout.device_mesh, sharding_spec=new_sharding_spec, - device_type=source_layout.device_type, - entire_shape=source_layout.entire_shape) + global_shape=source_layout.global_shape) valid_spec_dict[new_layout] = comm_spec except LayoutException: pass @@ -387,7 +396,7 @@ def layout_converting(self, source_layout: Layout, # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} @@ -395,16 +404,14 @@ def layout_converting(self, source_layout: Layout, # [R,S01,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [S01,R,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) @@ -493,21 +500,19 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - entire_shape = (4, 4, 4) + global_shape = (4, 4, 4) # [S0,R,R] sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + global_shape=global_shape) # [R,S0,R] sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + global_shape=global_shape) if rank in (0, 1): sharded_tensor_0 = torch.zeros(2, 1) diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index 5bec552d69d5..99d782c3f6e8 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -285,7 +285,7 @@ def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD # legal sharding dims means the mesh_id is still available to use. - legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))] + legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: legal_sharding_dims.remove(element) @@ -435,7 +435,7 @@ def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, """ input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) - output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis] peak_numel = max(peak_numel, alloc_numel + output_numel * 2) alloc_numel += output_numel if discard_input: @@ -461,7 +461,7 @@ def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, p # generate a new tensor input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) - output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] + output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis] alloc_numel += output_numel peak_numel = max(peak_numel, alloc_numel) if discard_input: diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index 406ad49097b5..e594fd297dc4 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -195,7 +195,7 @@ def __init__(self, def __repr__(self): res_list = ["DistSpec:"] res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) - res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}") + res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}") return ' '.join(res_list) def _sanity_check(self): @@ -222,7 +222,7 @@ def _sanity_check(self): num_devices = 1 for element in shard_list: - num_devices *= self.device_mesh.mesh_shape[element] + num_devices *= self.device_mesh.shape[element] if tensor_dim_size % num_devices != 0: raise ShardingNotDivisibleError( @@ -288,7 +288,7 @@ def get_sharded_shape_per_device(self): sharded_shape = list(self.entire_shape) for dim, shard_list in self.dim_partition_dict.items(): - mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list] shard_partitions = reduce(operator.mul, mesh_list, 1) assert sharded_shape[ dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' diff --git a/test.py b/test.py deleted file mode 100644 index f283e21a1ebd..000000000000 --- a/test.py +++ /dev/null @@ -1 +0,0 @@ -from colossalai.tensor.d_tensor.api import to_distributed_tensor diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py index fc9d8455ed5c..f0cf2a5fcbca 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -58,13 +58,4 @@ def test_evoformer_block(model, shape, max_memory): if __name__ == "__main__": - run_test( - rank=0, - data=get_data(LATENTS_SHAPE), - max_memory=None, - model=UNet2DModel, - print_code=False, - print_mem=True, - print_est_mem=False, - print_progress=False, - ) + test_evoformer_block() diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 14d69cab2176..602cf468c944 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -22,7 +22,7 @@ @parameterize('use_safetensors', [False, True]) def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool): from transformers import BertForSequenceClassification - (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() with shared_tempdir() as tempdir: @@ -53,7 +53,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b @parameterize('shard', [True, False]) @parameterize('model_name', ['transformers_gpt']) def exam_state_dict(placement_policy, shard: bool, model_name: str): - (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() plugin = GeminiPlugin(placement_policy=placement_policy) booster = Booster(plugin=plugin) diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index e9f0f9477e4a..590d6966bff6 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -8,18 +8,16 @@ def test_device_mesh(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], # [8, 9, 10,11], # [12,13,14,15]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - assert device_mesh.convert_map[5] == [1, 1] - assert device_mesh.convert_map[11] == [2, 3] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]] - assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]] - assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3] + assert device_mesh.global_rank_to_local_rank(5) == [1, 1] + assert device_mesh.global_rank_to_local_rank(11) == [2, 3] + assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3] def check_1d_device_mesh(): diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 2b7060c4846a..7c6339eff67e 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -20,16 +20,12 @@ def check_layer(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]} - logical_process_groups = device_mesh.process_groups_dict - - for mesh_dim, pgs in logical_pg_dict.items(): - for index, pg in enumerate(pgs): - if rank in pg: - tensor = torch.ones(4).cuda() - group = logical_process_groups[mesh_dim][index][1] - dist.all_reduce(tensor, op=ReduceOp.SUM, group=group) - assert tensor.equal(tensor_to_check) + + for axis in range(len(mesh_shape)): + tensor = torch.ones(4).cuda() + pg = device_mesh.get_process_group(axis=axis) + dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg) + assert tensor.equal(tensor_to_check) gpc.destroy() diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index 7a4bf131ae36..58c8132e1490 100644 --- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -1,3 +1,5 @@ +from typing import List + import torch from numpy import isin from torch.fx import GraphModule @@ -7,19 +9,23 @@ from colossalai._analyzer.fx import symbolic_trace -def trace_model_and_compare_output(model, data_gen): +def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = None): # must turn on eval mode to ensure the output is consistent model.eval() + inputs = data_gen() + + if ignore_data is not None: + # drop the ignore_data key + inputs = {k: v for k, v in inputs.items() if k not in ignore_data} + try: - kwargs = data_gen() - meta_args = {k: v.to('meta') for k, v in kwargs.items()} + meta_args = {k: v.to('meta') for k, v in inputs.items()} gm = symbolic_trace(model, meta_args=meta_args) except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") # run forward - inputs = data_gen() non_fx_out = model(**inputs) fx_out = gm(**inputs) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index f4d681221191..a1470400ad82 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -15,7 +15,7 @@ def test_albert(): sub_registry = model_zoo.get_sub_registry('transformers_albert') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() trace_model_and_compare_output(model, data_gen_fn) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index a833bb30c056..632ad366ccc4 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -12,9 +12,9 @@ def test_bert(): sub_registry = model_zoo.get_sub_registry('transformers_bert') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label']) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index ccbe2da58bf2..ac87a7fcb13b 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -47,7 +47,7 @@ def test_diffusers(): sub_model_zoo = model_zoo.get_sub_registry('diffusers') - for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() trace_and_compare(model_fn, data, output_transform_fn) torch.cuda.synchronize() diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 67107469d8bb..31bcb7028e25 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -12,7 +12,7 @@ def test_gpt(): sub_registry = model_zoo.get_sub_registry('transformers_gpt') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() # TODO: support the following models @@ -21,7 +21,7 @@ def test_gpt(): if model.__class__.__name__ in ['GPT2DoubleHeadsModel']: continue - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index 369545b03de1..f528db6a64ef 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -12,7 +12,7 @@ def test_opt(): sub_registry = model_zoo.get_sub_registry('transformers_opt') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() trace_model_and_compare_output(model, data_gen_fn) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 811cf3b21430..45e06bc2bbb0 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -12,9 +12,14 @@ def test_t5(): sub_registry = model_zoo.get_sub_registry('transformers_t5') - for name, (model_fn, data_gen_fn, _, _) in sub_registry.items(): + for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): + if name == "transformers_t5_for_conditional_generation": + # cannot trace for loss function yet + # so we use a data gen which does not produce labels + data_gen_fn = sub_registry.get('transformers_t5')[1] + model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) if __name__ == '__main__': diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 117c70c84aa8..98433b8f7c3b 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -56,7 +56,7 @@ def test_timm_models(): sub_model_zoo = model_zoo.get_sub_registry('timm') - for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: meta_args = {k: v.to('meta') for k, v in data.items()} diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py index f73c5bb9a590..2b7def5bef85 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -16,7 +16,7 @@ def test_torchaudio_models(): sub_model_zoo = model_zoo.get_sub_registry('torchaudio') - for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items(): model = model_fn() trace_and_compare(model, data_gen_fn, diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index df02568c0049..f969c8e6c3da 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -53,7 +53,7 @@ def test_torchrec_deepfm_models(): deepfm_models = model_zoo.get_sub_registry('deepfm') torch.backends.cudnn.deterministic = True - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: meta_args = {k: v.to('meta') for k, v in data.items()} diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 9776452be9c8..94fb24f33376 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -53,7 +53,7 @@ def test_torchrec_dlrm_models(): torch.backends.cudnn.deterministic = True dlrm_models = model_zoo.get_sub_registry('dlrm') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items(): data = data_gen_fn() # dlrm_interactionarch is not supported diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index bd259475ae5a..74cb753e2937 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -10,7 +10,7 @@ def test_torchvision_models(): torch.backends.cudnn.deterministic = True tv_sub_registry = model_zoo.get_sub_registry('torchvision') - for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, model_attribute) in tv_sub_registry.items(): data = data_gen_fn() if model_attribute is not None and model_attribute.has_stochastic_depth_prob: diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 3879363bcd1b..73c3c5422d8a 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -6,6 +6,7 @@ import torch from packaging import version +from colossalai.device.device_mesh import DeviceMesh from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor from colossalai.tensor.d_tensor import to_global from colossalai.tensor.d_tensor.layout import Layout @@ -82,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, print(f'{model.__class__.__name__} pass') -def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None: +def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh, + sharding_spec_dict: dict) -> None: state = model.state_dict() distributed_state = distributed_model.state_dict() diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py index f33c037e3de6..622d9deb601d 100644 --- a/tests/test_lazy/test_distribute.py +++ b/tests/test_lazy/test_distribute.py @@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]: return dim -def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout: +def make_sharding_spec(original_tensor: torch.Tensor) -> Layout: shard_dim = find_shard_dim(original_tensor.shape) dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {} target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - return layout + return target_sharding_spec def _get_current_name(prefix: str, name: str) -> str: return f'{prefix}.{name}'.lstrip('.') -def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict: - layout_dict = {} +def generate_sharding_spec_dict(model: nn.Module) -> dict: + sharding_spec_dict = {} @torch.no_grad() def generate_recursively(module: nn.Module, prefix: str = ''): @@ -53,17 +49,17 @@ def generate_recursively(module: nn.Module, prefix: str = ''): # initialize tensors directly attached to the current module for name, param in module.named_parameters(recurse=False): if isinstance(param, LazyTensor): - layout = make_layout(device_mesh, param) - layout_dict[_get_current_name(prefix, name)] = layout + sharding_spec = make_sharding_spec(param) + sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec for name, buf in module.named_buffers(recurse=False): if isinstance(buf, LazyTensor): - layout = make_layout(device_mesh, buf) - layout_dict[_get_current_name(prefix, name)] = layout + sharding_spec = make_sharding_spec(buf) + sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec generate_recursively(model) - return layout_dict + return sharding_spec_dict @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) @@ -75,7 +71,7 @@ def run_dist_lazy_init(subset, seed: int = 42): for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): + if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'): continue print_rank_0(name) model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry @@ -85,9 +81,9 @@ def run_dist_lazy_init(subset, seed: int = 42): ctx = LazyInitContext() with ctx: deferred_model = model_fn() - layout_dict = generate_layout_dict(deferred_model, device_mesh) - ctx.distribute(deferred_model, layout_dict, verbose=True) - assert_dist_model_equal(model, deferred_model, layout_dict) + sharding_spec_dict = generate_sharding_spec_dict(deferred_model) + ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True) + assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict) def run_dist(rank, world_size, port) -> None: diff --git a/tests/test_lazy/test_models.py b/tests/test_lazy/test_models.py index f828b23a94c4..4b7aeed73a69 100644 --- a/tests/test_lazy/test_models.py +++ b/tests/test_lazy/test_models.py @@ -10,7 +10,7 @@ def test_torchvision_models_lazy_init(subset): sub_model_zoo = model_zoo.get_sub_registry(subset) for name, entry in sub_model_zoo.items(): # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): + if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'): continue check_lazy_init(entry, verbose=True) diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index 958eabb65fac..95fcd2aaf8f3 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -122,23 +122,6 @@ def check_all_reduce_bwd(process_groups_dict, rank): assert tensor_to_comm.equal(tensor_to_check) -def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank): - # tensor to comm - tensor_to_comm = torch.ones(2, 2).cuda() * rank - - # reduce through logical process axis 0 at flatten device mesh - # tensor to check - # tensor([[6., 6.], - # [6., 6.]]) - tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda() - - # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1]) - comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0) - tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) - - assert tensor_to_comm.equal(tensor_to_check) - - def check_comm(rank, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -150,24 +133,22 @@ def check_comm(rank, world_size, port): # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - process_groups_dict = device_mesh.process_groups_dict + + process_group_dict = device_mesh._process_group_dict[rank] # test all gather - check_all_gather(process_groups_dict, rank) + check_all_gather(process_group_dict, rank) # test shard - check_shard(process_groups_dict, rank) + check_shard(process_group_dict, rank) # test all to all - check_all_to_all(process_groups_dict, rank) + check_all_to_all(process_group_dict, rank) # test all reduce - check_all_reduce_fwd(process_groups_dict, rank) - check_all_reduce_bwd(process_groups_dict, rank) + check_all_reduce_fwd(process_group_dict, rank) + check_all_reduce_bwd(process_group_dict, rank) - flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict - # test all reduce in 1D flatten device mesh - check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank) gpc.destroy() diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 8350fb3e7fe6..5a1aef79f332 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -64,7 +64,7 @@ def check_dtensor(rank, world_size, port): else: raise ValueError(f'rank {rank} is not in the device mesh') - dtensor_from_local = distribute_tensor(original_tensor, new_layout) + dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec) if rank == 0: assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1)) diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index d9dff8af933d..5388fd901e09 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn -entire_shape = torch.Size((64, 32, 16)) +global_shape = torch.Size((64, 32, 16)) layout_converter = LayoutConverter() -physical_mesh_id = torch.arange(0, 4).reshape(2, 2) +physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) @@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (2, 2) sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec, - entire_shape=entire_shape) + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) rst_dict = layout_converter.all_gather_transform_layouts(layout) @@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all) - layout_all2all = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_all2all, - entire_shape=entire_shape) + layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape) rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) @@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port): # shard_sequence: S0,R,R # device_mesh_shape: (4, 4) sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard) - shard_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_shard, - entire_shape=entire_shape) + shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape) rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) @@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) @@ -137,7 +122,7 @@ def check_layout_converting(rank, world_size, port): assert comm_action_sequence[2].shard_dim == 0 assert comm_action_sequence[2].logical_process_axis == 1 - # checkout cached_spec_pairs_transform_path + # checkout chached_spec_pairs_transform_path assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence @@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port): # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) - source_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_source, - entire_shape=entire_shape) + source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) - target_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=sharding_spec_target, - entire_shape=entire_shape) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape) - original_tensor = torch.rand(entire_shape).cuda() + original_tensor = torch.rand(global_shape).cuda() # tensor_to_apply: [R, S01, R] tensor_to_apply = original_tensor.narrow(1, rank * 8, 8) diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 6fe9ee292cd0..859eef051256 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -1,9 +1,10 @@ -from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern import torch -from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec + from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -physical_mesh_id = torch.arange(0, 16).reshape(2, 8) +physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py index d66d4fec14d1..9bd9805e9b8f 100644 --- a/tests/test_tensor/test_sharded_linear.py +++ b/tests/test_tensor/test_sharded_linear.py @@ -26,7 +26,7 @@ def run_dist(rank, world_size, port): # the mesh is in the following topo # [[0, 1], # [2, 3]] - physical_mesh_id = torch.arange(0, 4).reshape(2, 2) + physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) row_id = rank // 2 diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py index 909c84ef0f0e..5007c4141849 100644 --- a/tests/test_tensor/test_sharding_spec.py +++ b/tests/test_tensor/test_sharding_spec.py @@ -5,7 +5,7 @@ def test_sharding_spec(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) + physical_mesh_id = torch.arange(0, 16) mesh_shape = (4, 4) # [[0, 1, 2, 3], # [4, 5, 6, 7], From 4e0db9985e99dde68148093ff84af21c4407322a Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 26 Jun 2023 18:05:00 +0800 Subject: [PATCH 37/49] [shardformer] refactored layernorm (#4086) --- colossalai/shardformer/layer/__init__.py | 4 +- colossalai/shardformer/layer/layernorm.py | 101 +++++++----------- colossalai/shardformer/policies/bert.py | 12 +-- .../test_layer/test_layernorm.py | 11 +- 4 files changed, 51 insertions(+), 77 deletions(-) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 3ce0ef68aa4f..3ece2583132c 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,11 +1,11 @@ from .dropout import Dropout1D from .embedding import Embedding1D, VocabParallelEmbedding1D -from .layernorm import LayerNorm1D +from .layernorm import FusedLayerNorm from .linear import Linear1D_Col, Linear1D_Row from .linear_conv import LinearConv1D_Col, LinearConv1D_Row from .loss import cross_entropy_1d __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row", - "Dropout1D", "cross_entropy_1d", 'LayerNorm1D' + "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm' ] diff --git a/colossalai/shardformer/layer/layernorm.py b/colossalai/shardformer/layer/layernorm.py index a8e1d7a2c082..83854239cf90 100644 --- a/colossalai/shardformer/layer/layernorm.py +++ b/colossalai/shardformer/layer/layernorm.py @@ -1,89 +1,64 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import List, Union - import torch import torch.nn as nn -from torch.distributed import ProcessGroup - -from colossalai.kernel import LayerNorm -from colossalai.nn import init as init -from .parallel_module import ParallelModule +__all__ = ['FusedLayerNorm'] -__all__ = ['LayerNorm1D'] +FAST_LAYERNORM_SUPPORTED_SIZE = [ + 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576, + 25600, 30720, 32768, 40960, 49152, 65536 +] -Fast_LN = None -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm - Fast_LN = FastLayerNorm -except ImportError: - pass - -class LayerNorm1D(ParallelModule): +class FusedLayerNorm(): r""" - Layer Normalization for colossalai - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. """ - _fast_ln_supported_sizes = [ - 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, - 24576, 25600, 30720, 32768, 40960, 49152, 65536 - ] - - def __init__(self, - normalized_shape: int, - eps: int = 1e-05, - bias: bool = True, - dtype: torch.dtype = None, - device: torch.device = None): - super().__init__() - if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes: - norm = Fast_LN(normalized_shape, eps=eps).to(dtype) - else: - norm = None - try: - from apex.normalization import FusedLayerNorm - norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype) - except ImportError: - norm = LayerNorm(normalized_shape, eps=eps, device=device, dtype=dtype) - self.norm = norm + def __init__(self) -> None: + raise NotImplementedError( + 'FusedLayerNorm is not implemented as a physical class. ' + 'It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex.' + ) @staticmethod - def from_native_module(module: nn.LayerNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: r""" Convert a native pytorch layer norm module to colossalai layer norm module """ + # check if apex is installed + try: + import apex + except ImportError: + raise ImportError( + 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel') + + # get the attributes of the module normalized_shape = module.normalized_shape eps = module.eps - bias = module.bias is not None + elementwise_affine = module.elementwise_affine dtype = module.weight.dtype device = module.weight.device - # ensure only one process group is passed - if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' - process_group = process_group[0] + # pick the suitable layernorm implementation + use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE + + if use_fast_ln: + try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm + except ImportError: + # fall back to the normal fused layernorm is not built + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + else: + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm - # create layer norm - layer_norm = LayerNorm1D(normalized_shape, eps=eps, bias=bias, device=device, dtype=dtype).norm + layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps, + elementwise_affine=elementwise_affine).to(dtype).to(device) with torch.no_grad(): # copy weight and bias - layer_norm.weight.copy_(module.weight) - if bias: - layer_norm.bias.copy_(module.bias) - return layer_norm + layernorm.weight.copy_(module.weight) + layernorm.bias.copy_(module.bias) + return layernorm diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 1baf67ef9c02..7b0eaa5d8ab1 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -103,17 +103,17 @@ def module_policy(self): base_policy[BertLayer].sub_module_replacement.append( SubModuleReplacementDescription( suffix="attention.output.LayerNorm", - target_module=col_nn.LayerNorm1D, + target_module=col_nn.FusedLayerNorm, )) base_policy[BertLayer].sub_module_replacement.append( SubModuleReplacementDescription( suffix="output.LayerNorm", - target_module=col_nn.LayerNorm1D, + target_module=col_nn.FusedLayerNorm, )) base_policy[BertEmbeddings].sub_module_replacement.append( SubModuleReplacementDescription( suffix="LayerNorm", - target_module=col_nn.LayerNorm1D, + target_module=col_nn.FusedLayerNorm, ),) return base_policy @@ -154,7 +154,7 @@ def module_policy(self): addon_module[BertLMPredictionHead].sub_module_replacement.append( SubModuleReplacementDescription( suffix="transform.LayerNorm", - target_module=col_nn.LayerNorm1D, + target_module=col_nn.FusedLayerNorm, )) module_policy.update(addon_module) return module_policy @@ -191,7 +191,7 @@ def module_policy(self): addon_module[BertLMPredictionHead].sub_module_replacement.append( SubModuleReplacementDescription( suffix="transform.LayerNorm", - target_module=col_nn.LayerNorm1D, + target_module=col_nn.FusedLayerNorm, )) module_policy.update(addon_module) return module_policy @@ -228,7 +228,7 @@ def module_policy(self): addon_module[BertLMPredictionHead].sub_module_replacement.append( SubModuleReplacementDescription( suffix="transform.LayerNorm", - target_module=col_nn.LayerNorm1D, + target_module=col_nn.FusedLayerNorm, )) module_policy.update(addon_module) return module_policy diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index 334ae05bed95..a117845545be 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -1,16 +1,15 @@ import torch -import torch.distributed as dist import torch.nn as nn from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer import LayerNorm1D +from colossalai.shardformer.layer import FusedLayerNorm from colossalai.testing import rerun_if_address_is_in_use, spawn -def check_layernorm_1d(): +def check_layernorm(): norm = nn.LayerNorm(128, 0.00001).cuda() - norm1d = LayerNorm1D.from_native_module(norm, process_group=None) + norm1d = FusedLayerNorm.from_native_module(norm, process_group=None) assert norm1d.weight.shape == torch.Size([128]) @@ -33,11 +32,11 @@ def check_layernorm_1d(): def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_layernorm_1d() + check_layernorm() @rerun_if_address_is_in_use() -def test_layernorm_1d(): +def test_layernorm(): spawn(run_dist, nprocs=2) From a7433a099b59f5dbdfa46f53f4a6fe5d2c8ecc93 Mon Sep 17 00:00:00 2001 From: jiangmingyan <1829166702@qq.com> Date: Tue, 27 Jun 2023 17:39:29 +0800 Subject: [PATCH 38/49] [shardformer] shardformer support opt models (#4091) * [shardformer] shardformer support opt models * [shardformer] shardformer support opt models, fix * [shardformer] shardformer support opt models, fix * [shardformer] shardformer support opt models, fix --- colossalai/shardformer/policies/autopolicy.py | 10 ++ colossalai/shardformer/policies/opt.py | 133 ++++++++++++++++++ tests/kit/model_zoo/transformers/opt.py | 57 +++++++- .../test_tracer/test_hf_model/test_hf_opt.py | 3 +- tests/test_shardformer/test_model/_utils.py | 4 +- .../test_model/test_shard_opt.py | 67 +++++++++ 6 files changed, 264 insertions(+), 10 deletions(-) create mode 100644 colossalai/shardformer/policies/opt.py create mode 100644 tests/test_shardformer/test_model/test_shard_opt.py diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index b1b8c6156f9f..9cc583d58b11 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -68,6 +68,16 @@ class PolicyLocation: PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"), "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"), + + # OPT + "transformers.models.opt.modeling_opt.OPTModel": + PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), + "transformers.models.opt.modeling_opt.OPTForCausalLM": + PolicyLocation(file_name="opt", class_name="OPTForCausalLMPolicy"), + "transformers.models.opt.modeling_opt.OPTForSequenceClassification": + PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"), + "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": + PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"), } diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py new file mode 100644 index 000000000000..f467726e5580 --- /dev/null +++ b/colossalai/shardformer/policies/opt.py @@ -0,0 +1,133 @@ +from transformers.models.opt.modeling_opt import ( + OPTAttention, + OPTDecoder, + OPTDecoderLayer, + OPTForCausalLM, + OPTForSequenceClassification, +) + +from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row + +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + +class OPTPolicy(Policy): + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + base_policy = { + OPTDecoder: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=Embedding1D, + ) + ]), + OPTDecoderLayer: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]), + OPTAttention: + ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]), + } + if self.shard_config.fused_layernorm: + base_policy[OPTDecoder].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True)) + base_policy[OPTDecoderLayer].sub_module_replacement.extend([ + SubModuleReplacementDescription(suffix="self_attn_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True), + SubModuleReplacementDescription(suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True) + ]) + return base_policy + + def new_model_class(self): + return None + + def postprocess(self): + return self.model + + +class OPTModelPolicy(OPTPolicy): + + def __init__(self) -> None: + super().__init__() + + +class OPTForCausalLMPolicy(OPTPolicy): + + def module_policy(self): + policy = super().module_policy() + new_item = { + OPTForCausalLM: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) + } + + policy.update(new_item) + return policy + + +class OPTForSequenceClassificationPolicy(OPTPolicy): + + def __init__(self) -> None: + super().__init__() + + +class OPTForQuestionAnsweringPolicy(OPTPolicy): + + def __init__(self) -> None: + super().__init__() diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index d9c4a0b3c23c..4463ae12b901 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -11,14 +11,47 @@ def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) -output_transform_fn = lambda x: x +def data_gen_for_causal_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + labels = data['input_ids'].clone() + data['labels'] = labels + return data + + +def data_gen_for_sequence_classification(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + labels = data['input_ids'].clone() + data['labels'] = torch.tensor([1]) + return data + + +def data_gen_for_question_answering(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['start_positions'] = torch.tensor([0]) + data['end_positions'] = torch.tensor([1]) + return data + -config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4) +output_transform_fn = lambda x: x +loss_fn_for_opt_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_lm = lambda x: x.loss +config = transformers.OPTConfig( + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + dropout=0, +) # register the following models # transformers.OPTModel, @@ -27,9 +60,23 @@ def data_gen(): model_fn=lambda: transformers.OPTModel(config), data_gen_fn=data_gen, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_opt_model, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_opt_for_causal_lm', model_fn=lambda: transformers.OPTForCausalLM(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_causal_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_opt_for_question_answering', + model_fn=lambda: transformers.OPTForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_opt_for_sequence_classification', + model_fn=lambda: transformers.OPTForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index f528db6a64ef..c68b89e82fbe 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -11,10 +11,9 @@ @clear_cache_before_run() def test_opt(): sub_registry = model_zoo.get_sub_registry('transformers_opt') - for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'start_positions', 'end_positions']) if __name__ == '__main__': diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index a282e0bb919e..ad7c408aeb38 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -25,7 +25,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, # switch to train mode original_model.train() sharded_model.train() - # run forward org_output = original_model(**data) org_output = output_transform_fn(org_output) @@ -34,5 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, shard_output = sharded_model(**data) shard_output = output_transform_fn(shard_output) shard_loss = loss_fn(shard_output) - - return org_output, org_loss, shard_output, shard_loss + return org_output, org_loss, shard_output, shard_loss \ No newline at end of file diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py new file mode 100644 index 000000000000..4d4c55770144 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -0,0 +1,67 @@ +import copy +import os + +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import ( + assert_hf_output_close, + check_state_dict_equal, + clear_cache_before_run, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4) + + # run backward + org_loss.backward() + shard_loss.backward() + + # check grad + if hasattr(org_model, 'model'): + opt_model = org_model.model + shard_opt_model = sharded_model.model + else: + opt_model = org_model + shard_opt_model = sharded_model + + org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +def check_OPTModel(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(world_size, model_fn) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_OPTModel(): + spawn(check_OPTModel, 4) From ad604f7d2ed1b24f7455bd1b6b271e71514142a3 Mon Sep 17 00:00:00 2001 From: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Date: Wed, 28 Jun 2023 13:28:18 +0800 Subject: [PATCH 39/49] [shardformer] support vision transformer (#4096) * first v of vit shardformer * keep vit * update * vit shard add vitattention vitlayer * update num head shard para * finish test for vit * add new_model_class & postprocess * add vit readme * delete old files & fix the conflict * fix sth --- colossalai/shardformer/README.md | 2 +- colossalai/shardformer/layer/_operation.py | 2 +- colossalai/shardformer/layer/layernorm.py | 2 +- colossalai/shardformer/policies/bert.py | 2 +- colossalai/shardformer/policies/t5.py | 2 +- colossalai/shardformer/policies/vit.py | 96 +++++++++++++++++++ tests/test_device/test_device_mesh.py | 2 +- .../test_layer/test_layernorm.py | 2 +- .../test_model/test_shard_t5.py | 2 +- .../test_model/test_shard_vit.py | 55 +++++++++++ 10 files changed, 159 insertions(+), 8 deletions(-) create mode 100644 colossalai/shardformer/policies/vit.py create mode 100644 tests/test_shardformer/test_model/test_shard_vit.py diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index fee4cce7a28a..da80a7276b68 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -91,7 +91,7 @@ We will follow this roadmap to develop Shardformer: - [ ] GPT Neo - [ ] GPT-J - [ ] CV - - [ ] ViT + - [x] ViT - [ ] BEiT - [ ] SwinTransformer - [ ] SwinTransformer V2 diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 7e97bee01b33..c025daaeccc7 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -287,4 +287,4 @@ def reduce_forward(input_, process_group): def reduce_backward(input_, process_group): - return _ReduceBackward.apply(input_, process_group) + return _ReduceBackward.apply(input_, process_group) \ No newline at end of file diff --git a/colossalai/shardformer/layer/layernorm.py b/colossalai/shardformer/layer/layernorm.py index 83854239cf90..6103380fe8a5 100644 --- a/colossalai/shardformer/layer/layernorm.py +++ b/colossalai/shardformer/layer/layernorm.py @@ -61,4 +61,4 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: # copy weight and bias layernorm.weight.copy_(module.weight) layernorm.bias.copy_(module.bias) - return layernorm + return layernorm \ No newline at end of file diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 7b0eaa5d8ab1..fb70cdff8824 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -316,4 +316,4 @@ def module_policy(self): ]) } module_policy.update(addon_module) - return module_policy + return module_policy \ No newline at end of file diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 30433f751088..9a1b63e46d2c 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -167,4 +167,4 @@ def module_policy(self): class T5EncoderPolicy(T5ModelPolicy): - pass + pass \ No newline at end of file diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py new file mode 100644 index 000000000000..4a2b72057d05 --- /dev/null +++ b/colossalai/shardformer/policies/vit.py @@ -0,0 +1,96 @@ +from typing import Dict, Union + +import torch.nn as nn + +from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention + +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D + +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +class ViTPolicy(Policy): + + def preprocess(self): + # Resize embedding + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + + return self.model + + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + return { + ViTEmbeddings: + ModulePolicyDescription( + attribute_replacement{}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ) + ] + ), + ViTLayer: + ModulePolicyDescription( + attribute_replacement{ + "attention.attention.num_attention_heads": + self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": + self.model.config.hidden_size//self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=Dropout1D, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=Dropout1D, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=Dropout1D, + ), + ] + ), + } + + def new_model_class(self): + return None + + def postprocess(self): + return self.model + + + + + diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 590d6966bff6..1f8db99c9236 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -86,4 +86,4 @@ def test_device_mesh_from_process_group(): if __name__ == '__main__': test_device_mesh() - test_device_mesh_from_process_group() + test_device_mesh_from_process_group() \ No newline at end of file diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index a117845545be..080fae034956 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -41,4 +41,4 @@ def test_layernorm(): if __name__ == '__main__': - test_layernorm_1d() + test_layernorm_1d() \ No newline at end of file diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 2698d7675c8e..6074a902e9b0 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -56,4 +56,4 @@ def test_t5(): if __name__ == "__main__": - test_t5() + test_t5() \ No newline at end of file diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py new file mode 100644 index 000000000000..d5d71d9e29fe --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -0,0 +1,55 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output) + + # do backward + org_loss.backward() + shard_loss.backward() + + # check grad + org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad + shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +def check_vit(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + sub_model_zoo = model_zoo.get_sub_registry('transformers_vit') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(world_size, model_fn) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_vit(): + spawn(check_vit, 4) + + +if __name__ == "__main__": + test_vit() From 8b0930cfc3468cca260ac4eb3e4aab4534a4fc5b Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Wed, 28 Jun 2023 15:04:35 +0800 Subject: [PATCH 40/49] [shardformer] supported bloom model (#4098) --- colossalai/shardformer/README.md | 8 +- colossalai/shardformer/layer/__init__.py | 9 +- colossalai/shardformer/layer/dropout.py | 45 +++- colossalai/shardformer/layer/linear.py | 8 +- .../{linear_conv.py => qkv_fused_linear.py} | 103 ++++++--- colossalai/shardformer/layer/utils.py | 80 ++++++- colossalai/shardformer/policies/autopolicy.py | 12 + colossalai/shardformer/policies/basepolicy.py | 1 + colossalai/shardformer/policies/bert.py | 14 +- colossalai/shardformer/policies/bloom.py | 214 ++++++++++++++++++ colossalai/shardformer/policies/gpt2.py | 14 +- colossalai/shardformer/policies/t5.py | 14 +- colossalai/shardformer/policies/vit.py | 129 +++++------ colossalai/shardformer/shard/sharder.py | 16 +- tests/kit/model_zoo/transformers/__init__.py | 1 + tests/kit/model_zoo/transformers/bloom.py | 107 +++++++++ tests/kit/model_zoo/transformers/gpt.py | 2 - .../test_layer/test_dropout.py | 25 +- ...conv_1d.py => test_qkv_fused_linear_1d.py} | 15 +- .../test_model/test_shard_bloom.py | 59 +++++ 20 files changed, 723 insertions(+), 153 deletions(-) rename colossalai/shardformer/layer/{linear_conv.py => qkv_fused_linear.py} (79%) create mode 100644 colossalai/shardformer/policies/bloom.py create mode 100644 tests/kit/model_zoo/transformers/bloom.py rename tests/test_shardformer/test_layer/{test_linearconv_1d.py => test_qkv_fused_linear_1d.py} (81%) create mode 100644 tests/test_shardformer/test_model/test_shard_bloom.py diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index da80a7276b68..8a8ed0f792fd 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -83,8 +83,10 @@ We will follow this roadmap to develop Shardformer: - [x] BERT - [x] T5 - [x] LlaMa - - [ ] GPT2 - - [ ] BLOOM + - [x] GPT2 + - [x] OPT + - [x] BLOOM + - [ ] GLM - [ ] RoBERTa - [ ] ALBERT - [ ] ERNIE @@ -96,7 +98,7 @@ We will follow this roadmap to develop Shardformer: - [ ] SwinTransformer - [ ] SwinTransformer V2 - [ ] Audio - - [ ] To be added + - [ ] Whisper - [ ] Multi-modal - [ ] To be added diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 3ece2583132c..2826a8429f00 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,11 +1,12 @@ -from .dropout import Dropout1D +from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D from .layernorm import FusedLayerNorm from .linear import Linear1D_Col, Linear1D_Row -from .linear_conv import LinearConv1D_Col, LinearConv1D_Row from .loss import cross_entropy_1d +from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ - "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", "LinearConv1D_Col", "LinearConv1D_Row", - "Dropout1D", "cross_entropy_1d", 'FusedLayerNorm' + "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', + 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", + 'FusedLayerNorm' ] diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py index 2c49b49faad6..2625fe97889a 100644 --- a/colossalai/shardformer/layer/dropout.py +++ b/colossalai/shardformer/layer/dropout.py @@ -7,10 +7,10 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['Dropout1D'] +__all__ = ['DropoutForParallelInput', 'DropoutForReplicatedInput'] -class Dropout1D(ParallelModule, nn.Dropout): +class DropoutForParallelInput(ParallelModule, nn.Dropout): """ The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with randomness on different ranks of the given process group. This can avoid the same dropout mask is generated @@ -32,13 +32,50 @@ def __init__(self, p: float = 0.5, inplace: bool = False, process_group: Process @staticmethod def from_native_module(module: nn.Dropout, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Dropout1D": + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForParallelInput": + """ + Create a DropoutForParallelInput layer from a native dropout layer. + """ + p = module.p + inplace = module.inplace + return DropoutForParallelInput(p=p, inplace=inplace, process_group=process_group) + + def forward(self, input): + with self.randomizer.fork_rng(): + input = super().forward(input) + return input + + +class DropoutForReplicatedInput(ParallelModule, nn.Dropout): + """ + The Dropout Layer will apply dropout mask to the input tensor. The dropout mask is generated with + randomness on different ranks of the given process group. This can avoid the same dropout mask is generated + and applied on the same position of different ranks, leading to poor convergence performance. + + Args: + p (float): probability of an element to be zeroed. Defaults to 0.5. + inplace (bool): If set to True, will do this operation in-place. Defaults to False. + process_group (ProcessGroup): the process group to be used for generating randomness. Defaults to None. + """ + + def __init__(self, p: float = 0.5, inplace: bool = False, process_group: ProcessGroup = None): + # init with nn.Dropout + super(nn.Dropout, self).__init__(p=p, inplace=inplace) + + # offset the seed with randomizer index only + seed = torch.random.initial_seed() + self.randomizer = create_randomizer_with_offset(seed, process_group=process_group, offset_by_rank=False) + + @staticmethod + def from_native_module( + module: nn.Dropout, + process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "DropoutForReplicatedInput": """ Create a Dropout1D layer from a native dropout layer. """ p = module.p inplace = module.inplace - return Dropout1D(p=p, inplace=inplace, process_group=process_group) + return DropoutForReplicatedInput(p=p, inplace=inplace, process_group=process_group) def forward(self, input): with self.randomizer.fork_rng(): diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index d952d5eecbee..26ba5883c64f 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -277,6 +277,7 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis def chunk_weight(self): self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + @torch.no_grad() def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.out_features weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) @@ -289,9 +290,10 @@ def reset_parameters(self, weight_initializer, bias_initializer) -> None: src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0) origin_device = self.bias.device - self.bias = self.bias.cuda() - dist.broadcast(self.bias, src=src_rank, group=self.process_group) - self.bias = self.bias.to(origin_device) + bias = self.bias.cuda() + dist.broadcast(bias, src=src_rank, group=self.process_group) + bias = bias.to(origin_device) + self.bias.copy_(bias) def forward(self, input_: Tensor) -> Tensor: # Set up backprop all-reduce. diff --git a/colossalai/shardformer/layer/linear_conv.py b/colossalai/shardformer/layer/qkv_fused_linear.py similarity index 79% rename from colossalai/shardformer/layer/linear_conv.py rename to colossalai/shardformer/layer/qkv_fused_linear.py index e856abc14be6..9d51670c65dd 100644 --- a/colossalai/shardformer/layer/linear_conv.py +++ b/colossalai/shardformer/layer/qkv_fused_linear.py @@ -31,12 +31,25 @@ from .parallel_module import ParallelModule from .utils import create_randomizer_with_offset -__all__ = ['LinearConv1D_Col', 'LinearConv1D_Row'] +__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row'] +# ==================================== +# For GPT Only +# ==================================== -def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup): + +def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor, + n_fused: int, + process_group: ProcessGroup, + is_transposed: bool = False): """ The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2], this function will split them into [Q1, K1, V1] and [Q2, K2, V2]. + + Args: + qkv (torch.Tensor): The fused qkv tensor. + n_fused (int): The number items fused together, defaults to 3 (query, key and value). + process_group (ProcessGroup): The process group for distributed communication. + is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). """ # get the number of slice for the fused qkv rank = dist.get_rank(group=process_group) @@ -48,7 +61,10 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup # [Q, K, V] # to # [Q1, Q2, K1, K2, V1, V2] - weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1) + if is_transposed: + weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=-1) + else: + weight_chunks = torch.chunk(qkv, world_size * n_fused, dim=0) # rearrange the slice into the final order # from @@ -56,13 +72,26 @@ def split_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup # to # [Q1, K1, V1], [Q2, K2, V2] weight_chunks_of_current_rank = [weight_chunks[i] for i in order[rank::world_size]] - weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1) + + if is_transposed: + weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=-1) + else: + weight_of_current_rank = torch.cat(weight_chunks_of_current_rank, dim=0) return weight_of_current_rank -def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup): +def gather_fused_qkv_in_gpt2_style(qkv: torch.Tensor, + n_fused: int, + process_group: ProcessGroup, + is_transposed: bool = False): """ The splitted qkv tensor looks like [Q1, K1, V1] and [Q2, K2, V2], this function will gather them into [Q1, Q2, K1, K2, V1, V2]. + + Args: + qkv (torch.Tensor): The fused qkv tensor. + n_fused (int): The number items fused together, defaults to 3 (query, key and value). + process_group (ProcessGroup): The process group for distributed communication. + is_transposed (bool): generally the tensor is the shape of (out_features, in_features). Set this to True if the tensor is in the shape (in_features, out_features). """ world_size = dist.get_world_size(group=process_group) @@ -75,7 +104,11 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou qkv = qkv.cuda() gather_list = [torch.zeros_like(qkv) for _ in range(world_size)] dist.all_gather(gather_list, qkv, group=process_group) - gather_weight = torch.cat(gather_list, dim=-1) + + if is_transposed: + gather_weight = torch.cat(gather_list, dim=-1) + else: + gather_weight = torch.cat(gather_list, dim=0) gather_weight = gather_weight.to(origin_device) qkv = qkv.to(origin_device) @@ -84,15 +117,23 @@ def gather_fused_qkv(qkv: torch.Tensor, n_fused: int, process_group: ProcessGrou # [Q1, K1, V1, Q2, K2, V2] # to # [Q1, Q2, K1, K2, V1, V2] - weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1) + if is_transposed: + weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=-1) + else: + weight_chunks = torch.chunk(gather_weight, world_size * n_fused, dim=0) + reordered_chunk_list = [] for i in range(n_fused): reordered_chunk_list.extend(weight_chunks[i::n_fused]) - reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1) + + if is_transposed: + reordered_gather_weight = torch.cat(reordered_chunk_list, dim=-1) + else: + reordered_gather_weight = torch.cat(reordered_chunk_list, dim=0) return reordered_gather_weight -class LinearConv1D_Col(ParallelModule): +class GPT2FusedLinearConv1D_Col(ParallelModule): r"""Linear layer with column parallelism. The linear layer is defined as :math:`Y = XA + b`. A is parallelized along @@ -154,10 +195,10 @@ def __init__(self, weight = torch.empty(self.in_features, self.out_features, **factory_kwargs) def shard_fn(tensor): - return split_fused_qkv(tensor, self.n_fused, self.process_group) + return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True) def gather_fn(tensor): - return gather_fused_qkv(tensor, 3, self.process_group) + return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True) with torch.no_grad(): sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn) @@ -202,21 +243,27 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] - linear_1d = LinearConv1D_Col(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) + linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) # TODO: copy the sharded weights with torch.no_grad(): - sharded_weight = split_fused_qkv(module.weight.data, n_fused=n_fused, process_group=process_group) + sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data, + n_fused=n_fused, + process_group=process_group, + is_transposed=True) linear_1d.weight.data.copy_(sharded_weight.data) if bias: - sharded_bias = split_fused_qkv(module.bias.data, n_fused=n_fused, process_group=process_group) + sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data, + n_fused=n_fused, + process_group=process_group, + is_transposed=True) linear_1d.bias.data.copy_(sharded_bias.data) return linear_1d @@ -254,7 +301,7 @@ def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]: return output -class LinearConv1D_Row(ParallelModule): +class GPT2FusedLinearConv1D_Row(ParallelModule): r""" Linear layer with row parallelism. This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface. @@ -345,13 +392,13 @@ def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, Lis f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] - linear_1d = LinearConv1D_Row(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - process_group=process_group, - *args, - **kwargs) + linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + process_group=process_group, + *args, + **kwargs) # TODO: copy the sharded weights with torch.no_grad(): diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py index c3d6ab57e3e9..f2ac6563c46f 100644 --- a/colossalai/shardformer/layer/utils.py +++ b/colossalai/shardformer/layer/utils.py @@ -3,6 +3,7 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import _get_global_rank class Randomizer: @@ -112,27 +113,90 @@ def index(): """ idx = Randomizer._INDEX - Randomizer._INDEX += 1 return idx + @staticmethod + def increment_index(): + """ + Increment the index of the randomizer by one. + """ + Randomizer._INDEX += 1 + + @staticmethod + def is_randomizer_index_synchronized(process_group: ProcessGroup = None): + """ + Return whether the randomizer index is synchronized across processes. + """ + index = Randomizer.index() + if dist.is_initialized(): + # convert the index to tensor + index_tensor = torch.tensor(index, dtype=torch.int32).cuda() + + # all gather the index + gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] + dist.all_gather(gathered_index, index_tensor, process_group) + + # make sure all the gathered index are the same + for i in range(1, dist.get_world_size(process_group)): + if gathered_index[i] != gathered_index[0]: + return False + + return True -def create_randomizer_with_offset(seed: int, process_group: ProcessGroup = None): + @staticmethod + def synchronize_index(process_group: ProcessGroup = None): + """ + All gather the index and pick the largest value. + """ + index = Randomizer.index() + + if dist.is_initialized(): + # convert the index to tensor + index_tensor = torch.tensor(index, dtype=torch.int32).cuda() + + # all gather the index + gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))] + dist.all_gather(gathered_index, index_tensor, process_group) + + # pick the largest index + for i in range(1, dist.get_world_size(process_group)): + if gathered_index[i] > index_tensor: + index_tensor = gathered_index[i] + + # set the index + Randomizer._INDEX = index_tensor.item() + + +def create_randomizer_with_offset(seed: int, + process_group: ProcessGroup = None, + offset_by_rank: bool = True, + offset_by_index: bool = True): """ Create a randomizer with an offset. The offset is equal to the rank of the process and the index of the randomizer. Args: seed (int): The base random seed to set. - enable_cpu (bool): fork the CPU RNG state as well. process_group (ProcessGroup): the process group to get the rank from. + offset_by_rank (bool): whether to offset by the rank of the process, i.e., the rank of the process will be added to the seed. Default: True. + offset_by_index (bool): whether to offset by the index of the randomizer, i.e., the index of the randomizer will be added to the seed. Default: True. Returns: Randomizer: the randomizer with offset. """ - offset = Randomizer.index() + base_seed = seed - if dist.is_initialized(): + if offset_by_rank and dist.is_initialized(): rank = dist.get_rank(process_group) - offset += rank + base_seed += rank + + if offset_by_index: + # check if the randomizer index is synchronized + is_synchronized = Randomizer.is_randomizer_index_synchronized(process_group) + assert is_synchronized, ("We detect that the randomizer index is not synchronized across processes." + "This is not allowed when we want to create a randomizer with offset by index." + "Please call Randomizer.synchronize_index() first.") + + base_seed += Randomizer.index() + Randomizer.increment_index() - seed += offset - return Randomizer(seed=seed) + return Randomizer(seed=base_seed) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 9cc583d58b11..17c063c8d2cf 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -78,6 +78,18 @@ class PolicyLocation: PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"), "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"), + + # Bloom + "transformers.models.bloom.modeling_bloom.BloomModel": + PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForCausalLM": + PolicyLocation(file_name="bloom", class_name="BloomForCausalLMPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForSequenceClassification": + PolicyLocation(file_name="bloom", class_name="BloomForSequenceClassificationPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForTokenClassification": + PolicyLocation(file_name="bloom", class_name="BloomForTokenClassificationPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering": + PolicyLocation(file_name="bloom", class_name="BloomForQuestionAnsweringPolicy"), } diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index b5d9cdbd7289..7e9bcf209573 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -52,6 +52,7 @@ def example_replace_weight(module: torch.nn.Module, process_group): attribute_replacement: Dict[str, Any] param_replacement: List[Callable] sub_module_replacement: List[SubModuleReplacementDescription] + method_replacement: List[Callable] = None class Policy(ABC): diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index fb70cdff8824..49ef53259321 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -61,7 +61,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="attention.self.dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ), SubModuleReplacementDescription( suffix="attention.output.dense", @@ -69,7 +69,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="attention.output.dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ), SubModuleReplacementDescription( suffix="intermediate.dense", @@ -81,7 +81,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="output.dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ) ]), BertEmbeddings: @@ -94,7 +94,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ) ]) } @@ -258,7 +258,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ) ]) } @@ -281,7 +281,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ) ]) } @@ -311,7 +311,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ) ]) } diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py new file mode 100644 index 000000000000..d196bdbd6e4d --- /dev/null +++ b/colossalai/shardformer/policies/bloom.py @@ -0,0 +1,214 @@ +import torch +import torch.distributed as dist + +import colossalai.shardformer.layer as col_nn + +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + + +def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + import math + + if dist.is_initialized(): + world_size = dist.get_world_size() + num_heads = num_heads * world_size + + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2**math.floor(math.log2(num_heads)) + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + if dist.is_initialized(): + num_heads_per_rank = int(num_heads / dist.get_world_size()) + offset = dist.get_rank() * num_heads_per_rank + alibi = alibi.view(batch_size, num_heads, 1, seq_length) + alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] + return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) + else: + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) + + +class BloomPolicy(Policy): + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + # TODO: + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel + + return { + BloomBlock: + ModulePolicyDescription( + attribute_replacement={ + # 1. shard hidden size + "self_attention.hidden_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.split_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + # 2. shard number of heads + "self_attention.num_heads": + self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + # kwargs={'n_fused': 3} + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + ), + ]), + BloomModel: + ModulePolicyDescription(attribute_replacement={ + "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ]) + } + + def new_model_class(self): + # do nothing + return self.model + + def postprocess(self): + return self.model + + +# BertModel +class BloomModelPolicy(BloomPolicy): + pass + + +class BloomForCausalLMPolicy(BloomPolicy): + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForCausalLM + policy = super().module_policy() + # add a new item for casual lm + new_item = { + BloomForCausalLM: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + return policy + + +class BloomForSequenceClassificationPolicy(BloomPolicy): + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification + policy = super().module_policy() + # add a new item for casual lm + new_item = { + BloomForSequenceClassification: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="score", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + return policy + + +class BloomForTokenClassificationPolicy(BloomPolicy): + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomForTokenClassification + policy = super().module_policy() + # add a new item for casual lm + new_item = { + BloomForTokenClassification: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="classifier", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ]) + } + policy.update(new_item) + return policy + + +class BloomForQuestionAnsweringPolicy(BloomPolicy): + # No head sharding as the output features is only 2 + pass diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 9d5d7d36aea3..ebfaf8a8e1c3 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -42,37 +42,37 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="attn.c_attn", - target_module=col_nn.LinearConv1D_Col, + target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={ "n_fused": 3, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", - target_module=col_nn.LinearConv1D_Row, + target_module=col_nn.GPT2FusedLinearConv1D_Row, ), SubModuleReplacementDescription( suffix="mlp.c_fc", - target_module=col_nn.LinearConv1D_Col, + target_module=col_nn.GPT2FusedLinearConv1D_Col, kwargs={ "n_fused": 1, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", - target_module=col_nn.LinearConv1D_Row, + target_module=col_nn.GPT2FusedLinearConv1D_Row, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ), SubModuleReplacementDescription( suffix="attn.resid_dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ), SubModuleReplacementDescription( suffix="mlp.dropout", - target_module=col_nn.Dropout1D, + target_module=col_nn.DropoutForParallelInput, ), ]) } diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 9a1b63e46d2c..8d8abc9f7204 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -9,7 +9,7 @@ T5Stack, ) -from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -38,7 +38,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ) ]), T5LayerSelfAttention: @@ -47,7 +47,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ), ]), T5LayerCrossAttention: @@ -56,7 +56,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ) ]), T5Attention: @@ -97,7 +97,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ), ]), T5DenseGatedActDense: @@ -117,7 +117,7 @@ def module_policy(self): kwargs=dict(gather_output=True)), SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ) ]), T5DenseActDense: @@ -134,7 +134,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="dropout", - target_module=Dropout1D, + target_module=DropoutForParallelInput, ) ]) } diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 4a2b72057d05..550f8f997ae1 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,15 +1,15 @@ from typing import Dict, Union import torch.nn as nn +from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel -from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention - -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D +from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + class ViTPolicy(Policy): - + def preprocess(self): # Resize embedding vocab_size = self.model.config.vocab_size @@ -20,77 +20,68 @@ def preprocess(self): self.model.resize_token_embeddings(new_vocab_size) return self.model - + def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - return { + return { ViTEmbeddings: - ModulePolicyDescription( - attribute_replacement{}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=Dropout1D, - ) - ] - ), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForReplicatedInput, + ) + ]), ViTLayer: - ModulePolicyDescription( - attribute_replacement{ - "attention.attention.num_attention_heads": - self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size, - "attention.attention.all_head_size": - self.model.config.hidden_size//self.shard_config.tensor_parallel_size, - }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=Dropout1D, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=Dropout1D, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=Dropout1D, - ), - ] - ), + ModulePolicyDescription(attribute_replacement={ + "attention.attention.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=DropoutForParallelInput, + ), + ]), } - + def new_model_class(self): return None def postprocess(self): return self.model - - - - - diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 22f5f1c12d26..c2444e1f765c 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -95,8 +95,9 @@ def _replace_module(self,) -> None: attr_replacement = module_description[1].attribute_replacement param_replacement = module_description[1].param_replacement sub_module_replacement = module_description[1].sub_module_replacement + method_replacement = module_description[1].method_replacement self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement, - sub_module_replacement) + method_replacement, sub_module_replacement) def _recursive_replace_layer( self, @@ -104,6 +105,7 @@ def _recursive_replace_layer( origin_cls: nn.Module, attr_replacement: Dict[str, Any], param_replacement: List[Callable], + method_replacement: Dict[str, Callable], sub_module_replacement: List[Callable], ) -> None: r""" @@ -119,9 +121,11 @@ def _recursive_replace_layer( if module.__class__ == origin_cls: self._replace_attr(module, attr_replacement) self._replace_param(module, param_replacement) + self._replace_method(module, method_replacement) self._replace_sub_module(module, sub_module_replacement) + for name, child in module.named_children(): - self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, + self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement, sub_module_replacement) def _replace_attr( @@ -154,6 +158,14 @@ def _replace_param( # TODO: support parameter shard pass + def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]): + if method_replacement is None: + return + + for method_name, new_method in method_replacement.items(): + # bind the new method to the module + setattr(module, method_name, new_method.__get__(module, module.__class__)) + def _replace_sub_module( self, org_layer: nn.Module, diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py index ffaf4c566df9..4aa01abe13ee 100644 --- a/tests/kit/model_zoo/transformers/__init__.py +++ b/tests/kit/model_zoo/transformers/__init__.py @@ -1,5 +1,6 @@ from .albert import * from .bert import * +from .bloom import * from .gpt import * from .llama import * from .opt import * diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py new file mode 100644 index 000000000000..71146c0b9819 --- /dev/null +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -0,0 +1,107 @@ +import torch +import transformers + +from ..registry import ModelAttribute, model_zoo + +# =============================== +# Register Bloom +# =============================== + + +def data_gen(): + # Generated from following code snippet + # + # from transformers import BloomTokenizer + # input = 'Hello, my dog is cute' + # tokenized_input = tokenizer(input, return_tensors='pt') + # input_ids = tokenized_input['input_ids'] + # attention_mask = tokenized_input['attention_mask'] + input_ids = torch.tensor([[59414, 15, 2670, 35433, 632, 207595]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +def data_gen_for_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['labels'] = data['input_ids'].clone() + return data + + +def data_gen_for_token_classification(): + # token classification data gen + # `labels` is the type not the token id for token classification, 0 or 1 + data = data_gen() + data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64) + return data + + +def data_gen_for_sequence_classification(): + # sequence classification data gen + data = data_gen() + data['labels'] = torch.tensor([0], dtype=torch.int64) + return data + + +def data_gen_for_question_answering(): + # obtained with the following code + # + # from transformers import AutoTokenizer + # tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + # question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + # inputs = tokenizer(question, text, return_tensors="pt") + + input_ids = torch.tensor( + [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161]], dtype=torch.int64) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64) + return dict(input_ids=input_ids, attention_mask=attention_mask) + + +# define output transform function +output_transform_fn = lambda x: x + +# define loss function +loss_fn_for_bloom_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_causal_lm = lambda x: x.loss +loss_fn_for_classification = lambda x: x.logits.mean() +loss_fn_for_question_answering = lambda x: x.end_logits.mean() + +config = transformers.BloomConfig(n_layer=1, + n_head=4, + vocab_size=250880, + hidden_dropout=0, + attention_dropout=0, + hidden_size=64) + +# register the following models +model_zoo.register(name='transformers_bloom', + model_fn=lambda: transformers.BloomModel(config), + data_gen_fn=data_gen, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_bloom_model, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_causal_lm', + model_fn=lambda: transformers.BloomForCausalLM(config), + data_gen_fn=data_gen_for_lm, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_causal_lm, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_sequence_classification', + model_fn=lambda: transformers.BloomForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_token_classification', + model_fn=lambda: transformers.BloomForTokenClassification(config), + data_gen_fn=data_gen_for_token_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_classification, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_bloom_for_question_answering', + model_fn=lambda: transformers.BloomForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_question_answering, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index c598fa8f48e0..b9e0310780af 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -6,8 +6,6 @@ # =============================== # Register single-sentence GPT # =============================== -BATCH_SIZE = 1 # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined. -SEQ_LENGTH = 16 def data_gen(): diff --git a/tests/test_shardformer/test_layer/test_dropout.py b/tests/test_shardformer/test_layer/test_dropout.py index c62d25d94aa4..332e377110a4 100644 --- a/tests/test_shardformer/test_layer/test_dropout.py +++ b/tests/test_shardformer/test_layer/test_dropout.py @@ -3,13 +3,13 @@ import torch.nn as nn import colossalai -from colossalai.shardformer.layer import Dropout1D +from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn -def check_dropout(): +def check_dropout_parallel_input(): dropout = nn.Dropout().cuda() - dropout_1d = Dropout1D.from_native_module(dropout, process_group=None) + dropout_1d = DropoutForParallelInput.from_native_module(dropout, process_group=None) # check computation correctness x = torch.rand(4, 128).cuda() @@ -39,9 +39,26 @@ def check_dropout(): assert_not_equal(out_1d_all[i], out_1d_all[0]) +def check_dropout_replicated_input(): + dropout = nn.Dropout().cuda() + dropout_replica = DropoutForReplicatedInput.from_native_module(dropout, process_group=None) + + # check computation correctness + x = torch.rand(4, 128).cuda() + out_1d = dropout_replica(x) + + # ensure out_1d is different across ranks + world_size = dist.get_world_size() + out_1d_all = [torch.zeros_like(out_1d) for _ in range(world_size)] + dist.all_gather(out_1d_all, out_1d) + for i in range(1, world_size): + assert_equal(out_1d_all[i], out_1d_all[0]) + + def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_dropout() + check_dropout_parallel_input() + check_dropout_replicated_input() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_layer/test_linearconv_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py similarity index 81% rename from tests/test_shardformer/test_layer/test_linearconv_1d.py rename to tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py index 774e6340eee6..681c4f6dd9f1 100644 --- a/tests/test_shardformer/test_layer/test_linearconv_1d.py +++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py @@ -4,8 +4,8 @@ from torch.testing import assert_close import colossalai -from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row -from colossalai.shardformer.layer.linear_conv import split_fused_qkv +from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row +from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -52,7 +52,10 @@ def rearrange(tensor: torch.Tensor, dim: int): def check_linear_conv_1d_col(): linear = Conv1D(192, 48).cuda() - linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3) + linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear, + process_group=None, + gather_output=True, + n_fused=3) assert linear.weight.shape == torch.Size([48, 192]) assert linear.bias.shape == torch.Size([192]) @@ -73,13 +76,13 @@ def check_linear_conv_1d_col(): out.sum().backward() gather_out.sum().backward() - target_grad = split_fused_qkv(linear.weight.grad, 3, None) + target_grad = split_fused_qkv_in_gpt2_style(linear.weight.grad, 3, None, True) assert_close(target_grad, linear_conv_col.weight.grad) def check_linear_conv_1d_row(): linear = Conv1D(192, 48).cuda() - linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) + linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False) assert linear.weight.shape == torch.Size([48, 192]) assert linear_row.weight.shape == torch.Size([24, 192]) @@ -102,6 +105,8 @@ def check_linear_conv_1d_row(): def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + # test for linear conv check_linear_conv_1d_col() check_linear_conv_1d_row() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py new file mode 100644 index 000000000000..7e2e3dfa8f81 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -0,0 +1,59 @@ +import pytest +import torch + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + + # do backward + org_loss.backward() + shard_loss.backward() + + # check grad equality + if org_model.__class__.__name__ == 'BloomModel': + org_grad = org_model.h[0].self_attention.query_key_value.weight.grad + shard_grad = sharded_model.h[0].self_attention.query_key_value.weight.grad + else: + org_grad = org_model.transformer.h[0].self_attention.query_key_value.weight.grad + shard_grad = sharded_model.transformer.h[0].self_attention.query_key_value.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + +def check_bloom(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(world_size, model_fn) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom(): + spawn(check_bloom, 2) + + +if __name__ == "__main__": + test_bloom() From 92e669e7aebc3acb6c3a7822e3c76ee0d9cc415d Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 30 Jun 2023 09:32:37 +0800 Subject: [PATCH 41/49] [shardformer] supported fused normalization (#4112) --- colossalai/shardformer/layer/__init__.py | 4 +- .../layer/{layernorm.py => normalization.py} | 44 ++++++++++++++++++- colossalai/shardformer/policies/basepolicy.py | 8 ++++ colossalai/shardformer/policies/bert.py | 21 ++++++--- colossalai/shardformer/policies/bloom.py | 31 ++++++++++++- colossalai/shardformer/policies/gpt2.py | 29 +++++++++++- colossalai/shardformer/policies/llama.py | 28 +++++++++++- colossalai/shardformer/policies/opt.py | 8 +++- colossalai/shardformer/policies/t5.py | 22 ++++++++-- colossalai/shardformer/policies/vit.py | 27 +++++++++++- colossalai/shardformer/shard/shard_config.py | 10 +---- tests/test_shardformer/test_model/_utils.py | 6 +-- 12 files changed, 207 insertions(+), 31 deletions(-) rename colossalai/shardformer/layer/{layernorm.py => normalization.py} (59%) diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py index 2826a8429f00..7fad4948dfd0 100644 --- a/colossalai/shardformer/layer/__init__.py +++ b/colossalai/shardformer/layer/__init__.py @@ -1,12 +1,12 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .embedding import Embedding1D, VocabParallelEmbedding1D -from .layernorm import FusedLayerNorm from .linear import Linear1D_Col, Linear1D_Row from .loss import cross_entropy_1d +from .normalization import FusedLayerNorm, FusedRMSNorm from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row __all__ = [ "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d", - 'FusedLayerNorm' + 'FusedLayerNorm', 'FusedRMSNorm' ] diff --git a/colossalai/shardformer/layer/layernorm.py b/colossalai/shardformer/layer/normalization.py similarity index 59% rename from colossalai/shardformer/layer/layernorm.py rename to colossalai/shardformer/layer/normalization.py index 6103380fe8a5..b27307154a76 100644 --- a/colossalai/shardformer/layer/layernorm.py +++ b/colossalai/shardformer/layer/normalization.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -__all__ = ['FusedLayerNorm'] +__all__ = ['FusedLayerNorm', 'FusedRMSNorm'] FAST_LAYERNORM_SUPPORTED_SIZE = [ 1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576, @@ -61,4 +61,44 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: # copy weight and bias layernorm.weight.copy_(module.weight) layernorm.bias.copy_(module.bias) - return layernorm \ No newline at end of file + return layernorm + + +class FusedRMSNorm(): + """ + This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. + """ + + def __init__(self) -> None: + raise NotImplementedError( + 'FusedRMSNorm is not implemented as a physical class. ' + 'It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex.' + ) + + @staticmethod + def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: + try: + from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm + except ImportError: + raise ImportError( + 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel' + ) + + # to check if it is huggingface LlamaRMSNorm + if module.__class__.__name__ == "LlamaRMSNorm": + normalized_shape = module.weight.shape[0] + eps = module.variance_epsilon + elementwise_affine = True + else: + # get the attributes of the module + normalized_shape = module.normalized_shape + eps = module.eps + elementwise_affine = module.elementwise_affine + + rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + + with torch.no_grad(): + # copy weight and bias + rmsnorm.weight.copy_(module.weight) + + return rmsnorm diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 7e9bcf209573..8835e38cbbe4 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -98,6 +98,14 @@ def set_shard_config(self, shard_config: ShardConfig) -> None: shard_config (:class:`ShardConfig`): The shard config to be perform """ self.shard_config = shard_config + self.config_sanity_check() + + @abstractmethod + def config_sanity_check(self): + """ + Check if the shard config is valid for the model. Raise an exception if the config is invalid. + """ + pass @abstractmethod def preprocess(self) -> nn.Module: diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 49ef53259321..545669f1f463 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -16,6 +16,9 @@ class BertPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # reshape the embedding layer r""" @@ -99,7 +102,8 @@ def module_policy(self): ]) } - if self.shard_config.fused_layernorm: + # optimization configuration + if self.shard_config.enable_fused_normalization: base_policy[BertLayer].sub_module_replacement.append( SubModuleReplacementDescription( suffix="attention.output.LayerNorm", @@ -150,12 +154,16 @@ def module_policy(self): kwargs={"gather_output": True}), ]) } - if self.shard_config.fused_layernorm: + + # optimization configuration + if self.shard_config.enable_fused_normalization: addon_module[BertLMPredictionHead].sub_module_replacement.append( SubModuleReplacementDescription( suffix="transform.LayerNorm", target_module=col_nn.FusedLayerNorm, )) + + # append extra policy module_policy.update(addon_module) return module_policy @@ -187,7 +195,7 @@ def module_policy(self): kwargs={"gather_output": True}), ]) } - if self.shard_config.fused_layernorm: + if self.shard_config.enable_fused_normalization: addon_module[BertLMPredictionHead].sub_module_replacement.append( SubModuleReplacementDescription( suffix="transform.LayerNorm", @@ -224,12 +232,15 @@ def module_policy(self): kwargs={"gather_output": True}), ]) } - if self.shard_config.fused_layernorm: + + # optimization configuration + if self.shard_config.enable_fused_normalization: addon_module[BertLMPredictionHead].sub_module_replacement.append( SubModuleReplacementDescription( suffix="transform.LayerNorm", target_module=col_nn.FusedLayerNorm, )) + module_policy.update(addon_module) return module_policy @@ -316,4 +327,4 @@ def module_policy(self): ]) } module_policy.update(addon_module) - return module_policy \ No newline at end of file + return module_policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index d196bdbd6e4d..4e34f24643c2 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -65,6 +65,9 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, class BloomPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # reshape the embedding layer r""" @@ -81,7 +84,7 @@ def preprocess(self): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel - return { + base_policy = { BloomBlock: ModulePolicyDescription( attribute_replacement={ @@ -99,7 +102,6 @@ def module_policy(self): SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - # kwargs={'n_fused': 3} ), SubModuleReplacementDescription( suffix="self_attention.dense", @@ -132,6 +134,31 @@ def module_policy(self): ]) } + # optimization configuration + if self.shard_config.enable_fused_normalization: + base_policy[BloomModel].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="word_embeddings_layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ]) + base_policy[BloomBlock].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=col_nn.FusedLayerNorm, + ) + ]) + + return base_policy + def new_model_class(self): # do nothing return self.model diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index ebfaf8a8e1c3..3d6d94b8e90d 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -9,6 +9,9 @@ class GPT2Policy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # reshape the embedding layer r""" @@ -22,7 +25,7 @@ def preprocess(self): return self.model def module_policy(self): - return { + base_policy = { GPT2Model: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -77,6 +80,30 @@ def module_policy(self): ]) } + # optimization configuration + if self.shard_config.enable_fused_normalization: + base_policy[GPT2Model].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + )) + + base_policy[GPT2Block].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="ln_1", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="ln_2", + target_module=col_nn.FusedLayerNorm, + ), + SubModuleReplacementDescription(suffix="ln_cross_attn", + target_module=col_nn.FusedLayerNorm, + ignore_if_not_exist=True) + ]) + + return base_policy + def new_model_class(self): return self.model diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index a13f5f087da4..b36180ce3188 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -4,13 +4,16 @@ from transformers import LlamaForCausalLM, LlamaForSequenceClassification from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class LlamaPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # Resize embedding vocab_size = self.model.config.vocab_size @@ -23,7 +26,7 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - return { + base_policy = { LlamaDecoderLayer: ModulePolicyDescription( attribute_replacement={ @@ -75,6 +78,27 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ]) } + # optimization configuration + if self.shard_config.enable_fused_normalization: + base_policy[LlamaDecoderLayer].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="input_layernorm", + target_module=FusedRMSNorm, + ), + SubModuleReplacementDescription( + suffix="post_attention_layernorm", + target_module=FusedRMSNorm, + ) + ]) + + base_policy[LlamaModel].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + )) + + return base_policy + def new_model_class(self): return None diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index f467726e5580..ce3873954e15 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -13,6 +13,9 @@ class OPTPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # reshape the embedding layer r""" @@ -74,7 +77,9 @@ def module_policy(self): ), ]), } - if self.shard_config.fused_layernorm: + + # optimization configuration + if self.shard_config.enable_fused_normalization: base_policy[OPTDecoder].sub_module_replacement.append( SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, @@ -87,6 +92,7 @@ def module_policy(self): target_module=FusedLayerNorm, ignore_if_not_exist=True) ]) + return base_policy def new_model_class(self): diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 8d8abc9f7204..d35f688a0b61 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -9,7 +9,7 @@ T5Stack, ) -from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -18,6 +18,9 @@ class T5ModelPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # reshape the embedding layer r""" @@ -31,7 +34,7 @@ def preprocess(self): return self.model def module_policy(self): - return { + base_policy = { T5Stack: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -139,6 +142,19 @@ def module_policy(self): ]) } + # optimization configuration + if self.shard_config.enable_fused_normalization: + base_policy[T5LayerFF].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) + base_policy[T5LayerSelfAttention].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) + base_policy[T5LayerCrossAttention].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) + base_policy[T5Stack].sub_module_replacement.append( + SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm)) + + return base_policy + def new_model_class(self): return None @@ -167,4 +183,4 @@ def module_policy(self): class T5EncoderPolicy(T5ModelPolicy): - pass \ No newline at end of file + pass diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 550f8f997ae1..5d8a235db7a9 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -3,13 +3,16 @@ import torch.nn as nn from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel -from colossalai.shardformer.layer import DropoutForReplicatedInput, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription class ViTPolicy(Policy): + def config_sanity_check(self): + pass + def preprocess(self): # Resize embedding vocab_size = self.model.config.vocab_size @@ -22,7 +25,7 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - return { + base_policy = { ViTEmbeddings: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -80,6 +83,26 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ]), } + # optimization configuration + if self.shard_config.enable_fused_normalization: + base_policy[ViTAttention].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="layernorm_before", + target_module=FusedLayerNorm, + ), + SubModuleReplacementDescription( + suffix="layernorm_after", + target_module=FusedLayerNorm, + ) + ]) + base_policy[ViTModel].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="layernorm", + target_module=FusedLayerNorm, + )) + + return base_policy + def new_model_class(self): return None diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 8d3fc225e894..428ebc9780ba 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -12,16 +12,10 @@ class ShardConfig: Args: tensor_parallel_size (int): The size of tensor parallel - use_mixedfusedLN (bool): Whether to use the `MixedFusedLayerNorm` - data_parallel_size (int): The size of data parallel - pipeline_parallel_size (int): The size of pipeline parallel - tensor_parallel_mode (List): The mode of tensor parallel, choose from `['1d','2d','2.5d','3d'] - inference_only (bool): Whether to use the inference only mode, when setting to `True`, the model - will not calculate the loss and just return the output. - gather_output (bool): Whether to gather the output of the model of the last layer + enable_fused_normalization (bool): Whether to use fused layernorm, default is False """ tensor_parallel_size: int - fused_layernorm: bool = False + enable_fused_normalization: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index ad7c408aeb38..e49b0246ced5 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -8,11 +8,11 @@ def build_model(world_size, model_fn): org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(tensor_parallel_size=world_size, fused_layernorm=True) + shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) shard_former.init_distributed() - sharded_model = shard_former.shard_model(model_copy) + sharded_model = shard_former.shard_model(model_copy).cuda() return org_model, sharded_model @@ -33,4 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, shard_output = sharded_model(**data) shard_output = output_transform_fn(shard_output) shard_loss = loss_fn(shard_output) - return org_output, org_loss, shard_output, shard_loss \ No newline at end of file + return org_output, org_loss, shard_output, shard_loss From 8d3f077086af8d7015cb7e756a1357f343c23b21 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 30 Jun 2023 09:58:08 +0800 Subject: [PATCH 42/49] [shardformer] integrate with data parallelism (#4103) --- colossalai/shardformer/shard/shard_config.py | 16 ++-- colossalai/shardformer/shard/sharder.py | 11 +-- colossalai/shardformer/shard/shardformer.py | 25 +----- tests/test_shardformer/test_model/_utils.py | 6 +- .../test_model/test_shard_bert.py | 2 +- .../test_model/test_shard_bloom.py | 2 +- .../test_model/test_shard_gpt2.py | 2 +- .../test_model/test_shard_llama.py | 2 +- .../test_model/test_shard_opt.py | 2 +- .../test_model/test_shard_t5.py | 2 +- tests/test_shardformer/test_with_torch_ddp.py | 77 +++++++++++++++++++ 11 files changed, 97 insertions(+), 50 deletions(-) create mode 100644 tests/test_shardformer/test_with_torch_ddp.py diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 428ebc9780ba..e83191210a15 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -1,5 +1,8 @@ from dataclasses import dataclass +import torch.distributed as dist +from torch.distributed import ProcessGroup + from colossalai.cluster.dist_coordinator import DistCoordinator __all__ = ['ShardConfig'] @@ -11,10 +14,10 @@ class ShardConfig: The config for sharding the huggingface model Args: - tensor_parallel_size (int): The size of tensor parallel + tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. enable_fused_normalization (bool): Whether to use fused layernorm, default is False """ - tensor_parallel_size: int + tensor_parallel_process_group: int = None enable_fused_normalization: bool = False # TODO: add support for tensor parallel @@ -25,10 +28,5 @@ class ShardConfig: # gather_output: bool = True def __post_init__(self): - coordinator = DistCoordinator() - - # ensure the parallel size can match the world size - world_size = coordinator.world_size - self.data_parallel_size = world_size // self.tensor_parallel_size - assert world_size == self.data_parallel_size * self.tensor_parallel_size, \ - f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}" + # get the parallel size + self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index c2444e1f765c..e9b27ea45959 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -22,16 +22,10 @@ class ModelSharder(object): shard_config: The setting of distributed model """ - def __init__( - self, - model: nn.Module, - policy: Policy, - shard_config: ShardConfig = None, # TODO - pg_manager: ProcessGroupManager = None) -> None: + def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None: self.model = model self.policy = get_autopolicy(self.model) if policy is None else policy self.shard_config = shard_config - self.pg_manager = pg_manager def shard(self) -> None: r""" @@ -198,7 +192,8 @@ def _replace_sub_module( continue try: - replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'], + replace_layer = target_module.from_native_module(native_sub_module, + self.shard_config.tensor_parallel_process_group, **kwargs) except Exception as e: raise RuntimeError( diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index 1208a9d090fb..7c4220c3a9fb 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -1,7 +1,6 @@ import torch.nn as nn -from torch.utils.data import Dataset -from colossalai.cluster import DistCoordinator, ProcessGroupManager +from colossalai.cluster import DistCoordinator from ..policies.basepolicy import Policy from .shard_config import ShardConfig @@ -28,7 +27,6 @@ class ShardFormer: tensor_parallel_mode='1d', ) shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() model = shard_former.shard_model(org_model) ``` """ @@ -41,19 +39,6 @@ def __init__(self, shard_config: ShardConfig): """ self.coordinator = DistCoordinator() self.shard_config = shard_config - self.pg_manager = None - - def init_distributed(self) -> ProcessGroupManager: - """ - Initialize the distributed process group according to the - """ - # create process group manager and 1d process group - # TODO: may need to support other parallel mode when the config has such as field - pg_manager = ProcessGroupManager() - pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size)) - self.pg_manager = pg_manager - - return pg_manager def shard_model(self, model: nn.Module, policy: Policy = None): r""" @@ -64,12 +49,6 @@ def shard_model(self, model: nn.Module, policy: Policy = None): shard_config (`ShardConfig`): the config for distribute information policy (`Policy`): the custom policy for sharding """ - sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager) + sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) sharder.shard() return model - - def shard_dataset(self, dataset: Dataset): - """ - Shard dataset for DP - """ - pass diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index e49b0246ced5..a6355bf1c75e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,17 +3,15 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(world_size, model_fn): +def build_model(model_fn): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True) + shard_config = ShardConfig(enable_fused_normalization=True) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() sharded_model = shard_former.shard_model(model_copy).cuda() - return org_model, sharded_model diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index ad98e3d073d4..a089a1ab33cc 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -42,7 +42,7 @@ def check_bert(rank, world_size, port): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(world_size, model_fn) + org_model, sharded_model = build_model(model_fn) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 7e2e3dfa8f81..2e7ae7067467 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -42,7 +42,7 @@ def check_bloom(rank, world_size, port): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(world_size, model_fn) + org_model, sharded_model = build_model(model_fn) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 676267c2ca2a..4d4dc3c1e5b4 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -43,7 +43,7 @@ def check_gpt2(rank, world_size, port): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(world_size, model_fn) + org_model, sharded_model = build_model(model_fn) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 8b672af500bd..763fb2a6bf20 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -50,7 +50,7 @@ def check_llama(rank, world_size, port): sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(world_size, model_fn) + org_model, sharded_model = build_model(model_fn) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 4d4c55770144..d70b5d8e57d9 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -54,7 +54,7 @@ def check_OPTModel(rank, world_size, port): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(world_size, model_fn) + org_model, sharded_model = build_model(model_fn) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 6074a902e9b0..6f558e237970 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -42,7 +42,7 @@ def check_t5(rank, world_size, port): sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(world_size, model_fn) + org_model, sharded_model = build_model(model_fn) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py new file mode 100644 index 000000000000..61b672650965 --- /dev/null +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -0,0 +1,77 @@ +import pytest +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo + + +def check_shardformer_with_ddp(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') + + # create shardformer + # ranks: [0, 1, 2, 3] + # tp ranks = [0, 1], [2, 3] + # dp ranks = [0, 2], [1, 3] + dp_process_group_1 = dist.new_group([0, 2]) + dp_process_group_2 = dist.new_group([1, 3]) + tp_process_group_1 = dist.new_group([0, 1]) + tp_process_group_2 = dist.new_group([2, 3]) + + coordinator = DistCoordinator() + + if coordinator.rank in [0, 1]: + tp_process_group = tp_process_group_1 + else: + tp_process_group = tp_process_group_2 + + if coordinator.rank in [0, 2]: + dp_process_group = dp_process_group_1 + else: + dp_process_group = dp_process_group_2 + + shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True) + shardformer = ShardFormer(shard_config=shard_config) + + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + # create and shard model + model = model_fn().cuda() + sharded_model = shardformer.shard_model(model) + + # add ddp + sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group) + + # prepare input + data = data_gen_fn() + data = {k: v.cuda() for k, v in data.items()} + + # switch to train mode + sharded_ddp_model.train() + + # run forward + output = sharded_ddp_model(**data) + loss = loss_fn(output) + + # backward + loss.backward() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_gpt2(): + spawn(check_shardformer_with_ddp, 4) + + +if __name__ == "__main__": + test_gpt2() + test_gpt2() From 60d2cadbe4ee37e1ec6141cd6da8d4d4e1a61b2b Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 30 Jun 2023 10:56:29 +0800 Subject: [PATCH 43/49] [shardformer] import huggingface implicitly (#4101) --- colossalai/shardformer/policies/autopolicy.py | 2 ++ colossalai/shardformer/policies/basepolicy.py | 2 ++ colossalai/shardformer/policies/bert.py | 30 +++++++++++++------ colossalai/shardformer/policies/gpt2.py | 14 +++++++-- colossalai/shardformer/policies/llama.py | 12 ++++++-- colossalai/shardformer/policies/opt.py | 17 ++++++----- colossalai/shardformer/policies/t5.py | 27 +++++++++-------- colossalai/shardformer/policies/vit.py | 7 +++-- colossalai/shardformer/shard/shard_config.py | 18 ++++++++++- 9 files changed, 91 insertions(+), 38 deletions(-) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 17c063c8d2cf..8051433e8d71 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -5,6 +5,8 @@ from .basepolicy import Policy +__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"] + @dataclass class PolicyLocation: diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 8835e38cbbe4..2b972606948c 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -8,6 +8,8 @@ from ..shard.shard_config import ShardConfig +__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"] + class ParallelModule(): diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 545669f1f463..cec7f0eb2a6d 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -1,18 +1,16 @@ import torch.nn as nn -from transformers.models.bert.modeling_bert import ( - BertEmbeddings, - BertForMultipleChoice, - BertForSequenceClassification, - BertForTokenClassification, - BertLayer, - BertLMPredictionHead, -) import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +__all__ = [ + 'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy', + 'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy', + 'BertForMultipleChoicePolicy' +] + class BertPolicy(Policy): @@ -33,6 +31,8 @@ def preprocess(self): return self.model def module_policy(self): + from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer + base_policy = { BertLayer: ModulePolicyDescription( @@ -123,7 +123,7 @@ def module_policy(self): def new_model_class(self): # do nothing - return self.model + return None def postprocess(self): return self.model @@ -143,6 +143,8 @@ def __init__(self) -> None: super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + module_policy = super().module_policy() addon_module = { BertLMPredictionHead: @@ -184,6 +186,8 @@ def __init__(self) -> None: super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + module_policy = super().module_policy() addon_module = { BertLMPredictionHead: @@ -221,6 +225,8 @@ def __init__(self) -> None: super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + module_policy = super().module_policy() addon_module = { BertLMPredictionHead: @@ -261,6 +267,8 @@ def __init__(self) -> None: super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertForSequenceClassification + module_policy = super().module_policy() addon_module = { BertForSequenceClassification: @@ -284,6 +292,8 @@ def __init__(self) -> None: super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertForTokenClassification + module_policy = super().module_policy() addon_module = { BertForTokenClassification: @@ -314,6 +324,8 @@ def __init__(self) -> None: super().__init__() def module_policy(self): + from transformers.models.bert.modeling_bert import BertForMultipleChoice + module_policy = super().module_policy() addon_module = { BertForMultipleChoice: diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 3d6d94b8e90d..c6108f5c0e85 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -1,11 +1,15 @@ import torch.nn as nn -from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +__all__ = [ + 'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy', + 'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy' +] + class GPT2Policy(Policy): @@ -25,7 +29,9 @@ def preprocess(self): return self.model def module_policy(self): - base_policy = { + from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model + + return { GPT2Model: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -125,6 +131,8 @@ def __init__(self) -> None: super().__init__() def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + module_policy = super().module_policy() addon_module = { GPT2LMHeadModel: @@ -156,6 +164,8 @@ def __init__(self) -> None: super().__init__() def module_policy(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel + module_policy = super().module_policy() addon_module = { GPT2DoubleHeadsModel: diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index b36180ce3188..2fd2bc22303b 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -1,13 +1,13 @@ from typing import Dict, Union import torch.nn as nn -from transformers import LlamaForCausalLM, LlamaForSequenceClassification -from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy'] + class LlamaPolicy(Policy): @@ -26,7 +26,9 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - base_policy = { + from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel + + return { LlamaDecoderLayer: ModulePolicyDescription( attribute_replacement={ @@ -109,6 +111,8 @@ def postprocess(self): class LlamaForCausalLMPolicy(LlamaPolicy): def module_policy(self): + from transformers import LlamaForCausalLM + policy = super().module_policy() # add a new item for casual lm new_item = { @@ -128,6 +132,8 @@ def module_policy(self): class LlamaForSequenceClassificationPolicy(LlamaPolicy): def module_policy(self): + from transformers import LlamaForSequenceClassification + policy = super().module_policy() # add a new item for sequence classification diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index ce3873954e15..ec1bae20886a 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,15 +1,12 @@ -from transformers.models.opt.modeling_opt import ( - OPTAttention, - OPTDecoder, - OPTDecoderLayer, - OPTForCausalLM, - OPTForSequenceClassification, -) - from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +__all__ = [ + 'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy', + 'OPTForQuestionAnsweringPolicy' +] + class OPTPolicy(Policy): @@ -29,6 +26,8 @@ def preprocess(self): return self.model def module_policy(self): + from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer + base_policy = { OPTDecoder: ModulePolicyDescription(attribute_replacement={}, @@ -111,6 +110,8 @@ def __init__(self) -> None: class OPTForCausalLMPolicy(OPTPolicy): def module_policy(self): + from transformers.models.opt.modeling_opt import OPTForCausalLM + policy = super().module_policy() new_item = { OPTForCausalLM: diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index d35f688a0b61..845bfe727745 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,15 +1,4 @@ -from transformers import T5ForConditionalGeneration -from transformers.models.t5.modeling_t5 import ( - T5Attention, - T5DenseActDense, - T5DenseGatedActDense, - T5LayerCrossAttention, - T5LayerFF, - T5LayerSelfAttention, - T5Stack, -) - -from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -34,7 +23,17 @@ def preprocess(self): return self.model def module_policy(self): - base_policy = { + from transformers.models.t5.modeling_t5 import ( + T5Attention, + T5DenseActDense, + T5DenseGatedActDense, + T5LayerCrossAttention, + T5LayerFF, + T5LayerSelfAttention, + T5Stack, + ) + + return { T5Stack: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -165,6 +164,8 @@ def postprocess(self): class T5ForConditionalGenerationPolicy(T5ModelPolicy): def module_policy(self): + from transformers import T5ForConditionalGeneration + policy = super().module_policy() new_item = { diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 5d8a235db7a9..6a404c2faf0f 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,12 +1,13 @@ from typing import Dict, Union import torch.nn as nn -from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +__all__ = ['ViTPolicy'] + class ViTPolicy(Policy): @@ -25,7 +26,9 @@ def preprocess(self): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - base_policy = { + from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer + + return { ViTEmbeddings: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index e83191210a15..2116d2e622e2 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -19,6 +19,7 @@ class ShardConfig: """ tensor_parallel_process_group: int = None enable_fused_normalization: bool = False + enable_all_optimization: bool = False # TODO: add support for tensor parallel # pipeline_parallel_size: int @@ -27,6 +28,21 @@ class ShardConfig: # inference_only: bool = True # gather_output: bool = True + @property + def tensor_parallel_size(self): + return self._tensor_parallel_size + def __post_init__(self): # get the parallel size - self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + + # turn on all optimization if all_optimization is set to True + if self.enable_all_optimization: + self._turn_on_all_optimization() + + def _turn_on_all_optimization(self): + """ + Turn on all optimization. + """ + # you can add all the optimization flag here + self.fused_layernorm = True From 26ecfd7945483e9a1a351ac16d904fc25549ab3a Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Fri, 30 Jun 2023 16:16:44 +0800 Subject: [PATCH 44/49] [shardformer] added embedding gradient check (#4124) --- colossalai/shardformer/_utils.py | 4 +- colossalai/shardformer/policies/bert.py | 2 +- colossalai/shardformer/policies/bloom.py | 19 +++- colossalai/shardformer/policies/opt.py | 17 ++- colossalai/shardformer/policies/t5.py | 105 +++++++++++++++--- colossalai/shardformer/shard/sharder.py | 11 -- tests/kit/model_zoo/registry.py | 2 + .../test_model/test_shard_bert.py | 29 +++-- .../test_model/test_shard_bloom.py | 30 +++-- .../test_model/test_shard_gpt2.py | 30 +++-- .../test_model/test_shard_llama.py | 16 ++- .../test_model/test_shard_opt.py | 24 +++- .../test_model/test_shard_t5.py | 35 +++++- .../test_model/test_shard_vit.py | 1 + 14 files changed, 253 insertions(+), 72 deletions(-) diff --git a/colossalai/shardformer/_utils.py b/colossalai/shardformer/_utils.py index a1c7203a929f..4ad877e72357 100644 --- a/colossalai/shardformer/_utils.py +++ b/colossalai/shardformer/_utils.py @@ -55,7 +55,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False): except AttributeError: if ignore: return - raise AttributeError(f"Object {obj} has no attribute {attr}") + raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}") setattr(obj, attrs[-1], value) @@ -76,5 +76,5 @@ def getattr_(obj, attr: str, ignore: bool = False): except AttributeError: if ignore: return None - raise AttributeError(f"Object {obj} has no attribute {attr}") + raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}") return obj diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index cec7f0eb2a6d..7cf6caf7ca49 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -97,7 +97,7 @@ def module_policy(self): ), SubModuleReplacementDescription( suffix="dropout", - target_module=col_nn.DropoutForParallelInput, + target_module=col_nn.DropoutForReplicatedInput, ) ]) } diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 4e34f24643c2..c59cfbb405fc 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,8 +1,10 @@ import torch import torch.distributed as dist +import torch.nn as nn import colossalai.shardformer.layer as col_nn +from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription @@ -73,7 +75,6 @@ def preprocess(self): r""" Reshape the Embedding layer to make the embedding dimension divisible by world_size """ - # TODO: vocab_size = self.model.config.vocab_size world_size = self.shard_config.tensor_parallel_size if vocab_size % world_size != 0: @@ -161,13 +162,12 @@ def module_policy(self): def new_model_class(self): # do nothing - return self.model + return None def postprocess(self): return self.model -# BertModel class BloomModelPolicy(BloomPolicy): pass @@ -191,6 +191,19 @@ def module_policy(self): policy.update(new_item) return policy + def postprocess(self): + binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} + for k, v in binding_map.items(): + param = getattr_(self.model, k) + + if not isinstance(param, nn.Parameter): + param = nn.Parameter(param) + + # tie weights + setattr_(self.model, k, param) + setattr_(self.model, v, param) + return self.model + class BloomForSequenceClassificationPolicy(BloomPolicy): diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index ec1bae20886a..dfbaaf5785ba 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,5 +1,6 @@ -from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D +from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = [ @@ -35,7 +36,7 @@ def module_policy(self): sub_module_replacement=[ SubModuleReplacementDescription( suffix="embed_tokens", - target_module=Embedding1D, + target_module=VocabParallelEmbedding1D, ) ]), OPTDecoderLayer: @@ -127,6 +128,18 @@ def module_policy(self): policy.update(new_item) return policy + def postprocess(self): + binding_map = { + 'model.decoder.embed_tokens': 'lm_head', + } + + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = getattr_(self.model, v) + dst_mod.weight = src_mod.weight + + return self.model + class OPTForSequenceClassificationPolicy(OPTPolicy): diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 845bfe727745..8853687e7621 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,11 +1,20 @@ -from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row +from colossalai.shardformer.layer import ( + DropoutForParallelInput, + Embedding1D, + FusedRMSNorm, + Linear1D_Col, + Linear1D_Row, + VocabParallelEmbedding1D, +) +from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription +from .._utils import getattr_, setattr_ from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription __all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] -class T5ModelPolicy(Policy): +class T5BasePolicy(Policy): def config_sanity_check(self): pass @@ -33,7 +42,7 @@ def module_policy(self): T5Stack, ) - return { + base_policy = { T5Stack: ModulePolicyDescription(attribute_replacement={}, param_replacement=[], @@ -41,6 +50,10 @@ def module_policy(self): SubModuleReplacementDescription( suffix="dropout", target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=Embedding1D, ) ]), T5LayerSelfAttention: @@ -158,30 +171,86 @@ def new_model_class(self): return None def postprocess(self): + binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] + + for k, v in binding_map: + mod = getattr_(self.model, k) + setattr_(self.model, v, mod) return self.model -class T5ForConditionalGenerationPolicy(T5ModelPolicy): +class T5ModelPolicy(T5BasePolicy): + + def module_policy(self): + from transformers import T5Model + + base_policy = super().module_policy() + base_policy[T5Model] = ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ) + ]) + return base_policy + + +class T5ForConditionalGenerationPolicy(T5BasePolicy): def module_policy(self): from transformers import T5ForConditionalGeneration policy = super().module_policy() + policy[T5ForConditionalGeneration] = ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) + return policy - new_item = { - T5ForConditionalGeneration: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)) - ]) - } + def postprocess(self): + super().postprocess() + + binding_map = {"shared": "lm_head"} + + for k, v in binding_map.items(): + src_mod = getattr_(self.model, k) + dst_mod = getattr_(self.model, v) + dst_mod.weight = src_mod.weight + + return self.model - policy.update(new_item) - return policy +class T5EncoderPolicy(T5BasePolicy): -class T5EncoderPolicy(T5ModelPolicy): - pass + def module_policy(self): + from transformers import T5EncoderModel + + base_policy = super().module_policy() + base_policy[T5EncoderModel] = ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ) + ]) + return base_policy + + def postprocess(self): + binding_map = [ + ["shared", "encoder.embed_tokens"], + ] + + for k, v in binding_map: + mod = getattr_(self.model, k) + setattr_(self.model, v, mod) + return self.model diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index e9b27ea45959..81c032b95f03 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -38,17 +38,6 @@ def shard(self) -> None: self._replace_module() self._postprocess() - def reshape_embedding(self) -> None: - r""" - Reshape the Embedding layer to make the embedding dimension divisible by world_size - """ - vocab_size = self.model_config.vocab_size - world_size = self.shard_config.world_size - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - self.model_config = self.model.config - def _preprocess(self) -> None: self.model = self.policy.preprocess() diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index efbf3a4d37b1..1e7ef3b62736 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -70,6 +70,8 @@ def get_sub_registry(self, keyword: str): for k, v in self.items(): if keyword in k: new_dict[k] = v + + assert len(new_dict) > 0, f'No model found with keyword {keyword}' return new_dict diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index a089a1ab33cc..87c4ef65bf1a 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -18,20 +18,35 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo org_loss.backward() shard_loss.backward() - # check grad equality + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check grad + if org_model.__class__.__name__ == 'BertModel': - org_grad = org_model.encoder.layer[0].attention.self.query.weight.grad - shard_grad = sharded_model.encoder.layer[0].attention.self.query.weight.grad + bert = org_model + sharded_bert = sharded_model else: - org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad - shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad + bert = org_model.bert + sharded_bert = sharded_model.bert + + # compare self attention grad + org_grad = bert.encoder.layer[0].attention.self.query.weight.grad + shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + # compare embedding grad + org_grad = bert.embeddings.word_embeddings.weight.grad + shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 2e7ae7067467..70d902a04517 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -18,20 +18,36 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo org_loss.backward() shard_loss.backward() - # check grad equality + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model if org_model.__class__.__name__ == 'BloomModel': - org_grad = org_model.h[0].self_attention.query_key_value.weight.grad - shard_grad = sharded_model.h[0].self_attention.query_key_value.weight.grad + bloom = org_model + sharded_bloom = sharded_model else: - org_grad = org_model.transformer.h[0].self_attention.query_key_value.weight.grad - shard_grad = sharded_model.transformer.h[0].self_attention.query_key_value.weight.grad + bloom = org_model.transformer + sharded_bloom = sharded_model.transformer + + # check attention grad + org_grad = bloom.h[0].self_attention.query_key_value.weight.grad + shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # check embedding weights + org_grad = bloom.word_embeddings.weight.grad + shard_grad = sharded_bloom.word_embeddings.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=0) - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index 4d4dc3c1e5b4..a4edc14bdbc3 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -18,20 +18,36 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo org_loss.backward() shard_loss.backward() - # check grad equality + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model if org_model.__class__.__name__ == 'GPT2Model': - org_grad = org_model.h[0].mlp.c_fc.weight.grad - shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad + org_model = org_model + sharded_model = sharded_model else: - org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad - shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad + org_model = org_model.transformer + sharded_model = sharded_model.transformer + + # check mlp grad + org_grad = org_model.h[0].mlp.c_fc.weight.grad + shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=1) - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose( + org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" + + # check embedding weights + org_grad = org_model.wte.weight.grad + shard_grad = sharded_model.wte.weight.grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose( org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 763fb2a6bf20..a98743a6143a 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -23,7 +23,10 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo org_loss.backward() shard_loss.backward() - # check grad + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model if hasattr(org_model, 'model'): llama_model = org_model.model shard_llama_model = sharded_model.model @@ -31,14 +34,21 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo llama_model = org_model shard_llama_model = sharded_model + # check attention grad org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + # check embedding grad + org_grad = llama_model.embed_tokens.weight.grad + shard_grad = shard_llama_model.embed_tokens.weight.grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index d70b5d8e57d9..29cf2f6beed8 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -28,7 +28,10 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo org_loss.backward() shard_loss.backward() - # check grad + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # unwrap model if hasattr(org_model, 'model'): opt_model = org_model.model shard_opt_model = sharded_model.model @@ -36,16 +39,23 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo opt_model = org_model shard_opt_model = sharded_model + # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + # check embedding grad + org_grad = opt_model.decoder.embed_tokens.weight.grad + shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=0) - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" assert torch.allclose(org_grad, all_shard_grad, - atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" def check_OPTModel(rank, world_size, port): @@ -65,3 +75,7 @@ def check_OPTModel(rank, world_size, port): @clear_cache_before_run() def test_OPTModel(): spawn(check_OPTModel, 4) + + +if __name__ == '__main__': + test_OPTModel() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 6f558e237970..91430bce918f 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -21,19 +21,43 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo org_loss.backward() shard_loss.backward() - # check grad equality + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + + # check attention grad org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=0) - - assert torch.allclose(org_loss, shard_loss, - atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + # check self attention embed + org_grad = org_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad + shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=1) + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + + # check token embedding grad + org_grad = org_model.shared.weight.grad + + # check weights are tied + if hasattr(org_model, 'lm_head'): + assert org_model.shared.weight.data.data_ptr() == org_model.lm_head.weight.data.data_ptr() + assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr() + + shard_grad = sharded_model.shared.weight.grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" + def check_t5(rank, world_size, port): disable_existing_loggers() @@ -44,7 +68,6 @@ def check_t5(rank, world_size, port): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(model_fn) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() @@ -56,4 +79,4 @@ def test_t5(): if __name__ == "__main__": - test_t5() \ No newline at end of file + test_t5() diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index d5d71d9e29fe..af1605b6b659 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -45,6 +45,7 @@ def check_vit(rank, world_size, port): @pytest.mark.dist +@pytest.mark.skip @rerun_if_address_is_in_use() @clear_cache_before_run() def test_vit(): From b6f4e0582edd97afe0d87cf9f3a2d26e0bffb4c0 Mon Sep 17 00:00:00 2001 From: jiangmingyan <1829166702@qq.com> Date: Fri, 30 Jun 2023 16:48:29 +0800 Subject: [PATCH 45/49] [shardformer] write an shardformer example with bert finetuning (#4126) * [shardformer] add benchmark of shardformer * [shardformer] add benchmark of shardformer --- colossalai/shardformer/README.md | 13 ++ colossalai/shardformer/examples/data.py | 146 +++++++++++++++++ .../examples/shardformer_benchmark.py | 154 ++++++++++++++++++ .../examples/shardformer_benchmark.sh | 9 + colossalai/shardformer/shard/shard_config.py | 2 +- 5 files changed, 323 insertions(+), 1 deletion(-) create mode 100644 colossalai/shardformer/examples/data.py create mode 100644 colossalai/shardformer/examples/shardformer_benchmark.py create mode 100644 colossalai/shardformer/examples/shardformer_benchmark.sh diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 8a8ed0f792fd..877e28a2db0e 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -15,6 +15,7 @@ - [Policy](#policy) - [Model Sharder](#model-sharder) - [User-facing API](#user-facing-api) + - [Shardformer Convergence](#shardformer-convergence) ## 🔗 Introduction @@ -324,3 +325,15 @@ class ShardFormer: """ ... ``` + +### Shardformer Convergence + +To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. + +| accuracy | f1 | loss | GPU number | model shard | +| :-----: | :----: | :----: | :----: | :----: | +| 0.82594 | 0.87441 | 0.09913 | 4 | True | +| 0.81884 | 0.87299 | 0.10120 | 2 | True | +| 0.81855 | 0.87124 | 0.10357 | 1 | False | + +Overall, the results demonstrate that using shardformers during model training does not affect the convergence. diff --git a/colossalai/shardformer/examples/data.py b/colossalai/shardformer/examples/data.py new file mode 100644 index 000000000000..6296d4be4eb0 --- /dev/null +++ b/colossalai/shardformer/examples/data.py @@ -0,0 +1,146 @@ +import datasets +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, PreTrainedTokenizer + +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase + + +class GLUEDataBuilder: + + task_text_field_map = { + "cola": ["sentence"], + "sst2": ["sentence"], + "mrpc": ["sentence1", "sentence2"], + "qqp": ["question1", "question2"], + "stsb": ["sentence1", "sentence2"], + "mnli": ["premise", "hypothesis"], + "qnli": ["question", "sentence"], + "rte": ["sentence1", "sentence2"], + "wnli": ["sentence1", "sentence2"], + "ax": ["premise", "hypothesis"], + } + + glue_task_num_labels = { + "cola": 2, + "sst2": 2, + "mrpc": 2, + "qqp": 2, + "stsb": 1, + "mnli": 3, + "qnli": 2, + "rte": 2, + "wnli": 2, + "ax": 3, + } + + loader_columns = [ + "datasets_idx", + "input_ids", + "token_type_ids", + "attention_mask", + "start_positions", + "end_positions", + "labels", + ] + + def __init__( + self, + model_name_or_path: str, + plugin: DPPluginBase = None, + task_name: str = "mrpc", + max_seq_length: int = 128, + train_batch_size: int = 32, + eval_batch_size: int = 32, + **kwargs, + ): + super().__init__() + self.model_name_or_path = model_name_or_path + self.task_name = task_name + self.max_seq_length = max_seq_length + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size + self.plugin = plugin + + self.text_fields = self.task_text_field_map[task_name] + self.num_labels = self.glue_task_num_labels[task_name] + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + self.setup() + + def setup(self): + self.dataset = datasets.load_dataset("glue", self.task_name) + + for split in self.dataset.keys(): + self.dataset[split] = self.dataset[split].map( + self.convert_to_features, + batched=True, + remove_columns=["label"], + ) + self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] + self.dataset[split].set_format(type="torch", columns=self.columns) + + self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] + + def prepare_data(self): + datasets.load_dataset("glue", self.task_name) + AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + + def train_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["train"], + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True) + return self.plugin.prepare_dataloader(self.dataset["train"], + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True) + + def val_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def test_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["test"], batch_size=self.train_batch_size) + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def convert_to_features(self, example_batch): + + # Either encode single sentence or sentence pairs + if len(self.text_fields) > 1: + texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) + else: + texts_or_text_pairs = example_batch[self.text_fields[0]] + + # Tokenize the text/text pairs + features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, + max_length=self.max_seq_length, + padding='max_length', + truncation=True) + + # Rename label to labels to make it easier to pass to model forward + features["labels"] = example_batch["label"] + + return features + + def native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False): + + return DataLoader(dataset, + batch_size=batch_size, + sampler=None, + shuffle=shuffle, + drop_last=drop_last, + pin_memory=pin_memory) diff --git a/colossalai/shardformer/examples/shardformer_benchmark.py b/colossalai/shardformer/examples/shardformer_benchmark.py new file mode 100644 index 000000000000..bb3560ee9e21 --- /dev/null +++ b/colossalai/shardformer/examples/shardformer_benchmark.py @@ -0,0 +1,154 @@ +import argparse +import math +from typing import Any, List, Union + +import evaluate +import torch +import torch.distributed as dist +from data import GLUEDataBuilder +from torch import nn +from torch.optim import Adam, AdamW, Optimizer +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer import ShardConfig, ShardFormer + + +def to_device(x: Any, device: torch.device) -> Any: + + def _to(t: Any): + if isinstance(t, torch.Tensor): + return t.to(device) + return t + + return tree_map(_to, x) + + +def train(args): + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + # prepare for data and dataset + data_builder = GLUEDataBuilder(model_name_or_path=args.pretrain, + task_name=args.task, + train_batch_size=args.batch_size, + eval_batch_size=args.batch_size) + train_dataloader = data_builder.train_dataloader() + test_dataloader = data_builder.test_dataloader() + + if args.model == "bert": + cfg = BertConfig.from_pretrained(args.pretrain, num_labels=data_builder.num_labels) + model = BertForSequenceClassification.from_pretrained(args.pretrain, config=cfg) + + model.to(torch.cuda.current_device()) + + # if multiple GPUs, shard the model + if dist.get_world_size() > 1: + shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm) + shard_former = ShardFormer(shard_config=shard_config) + model = shard_former.shard_model(model) + + optim = Adam(model.parameters(), lr=args.lr) + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + lr_scheduler = get_linear_schedule_with_warmup( + optim, + num_warmup_steps=math.ceil(max_steps * args.warmup_fraction), + num_training_steps=max_steps, + ) + fit(model, optim, lr_scheduler, train_dataloader, args.max_epochs, args.accumulation_steps, args.batch_size, + coordinator) + results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, + coordinator) + if coordinator.is_master(): + print(results) + if args.target_f1 is not None and 'f1' in results: + assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +def fit(model: nn.Module, optimizer: Optimizer, scheduler, train_dataloader, max_epochs, accumulation_steps, batch_size, + coordinator): + step_bar = tqdm(range(len(train_dataloader) // accumulation_steps * max_epochs), + desc=f'steps', + disable=not coordinator.is_master()) + total_loss = 0 + for epoch in range(max_epochs): + model.train() + for batch_id, batch in enumerate(train_dataloader): + batch = to_device(batch, torch.cuda.current_device()) + outputs = model(**batch) + loss = outputs.loss + loss = loss / accumulation_steps + loss.backward() + total_loss += loss.item() + if (batch_id + 1) % accumulation_steps == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + step_bar.set_postfix({ + 'epoch': epoch, + 'loss': total_loss / batch_size, + 'lr': scheduler.get_last_lr()[0] + }) + total_loss = 0 + step_bar.update() + + +# evaluate +@torch.no_grad() +def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, + task_name: str, eval_splits: List[str], coordinator: DistCoordinator): + metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) + model.eval() + + def evaluate_subset(dataloader: DataLoader): + accum_loss = torch.zeros(1, device=torch.cuda.current_device()) + for batch in dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + labels = batch["labels"] + metric.add_batch(predictions=preds, references=labels) + + results = metric.compute() + if coordinator.is_master(): + results['loss'] = accum_loss.item() / (len(dataloader) * dataloader.batch_size) + return results + + if isinstance(test_dataloader, DataLoader): + return evaluate_subset(test_dataloader) + else: + assert len(test_dataloader) == len(eval_splits) + final_results = {} + for split, sub_loader in zip(eval_splits, test_dataloader): + results = evaluate_subset(sub_loader) + final_results.update({f'{k}_{split}': v for k, v in results.items()}) + return final_results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") + parser.add_argument('--model', type=str, default="bert") + parser.add_argument('--pretrain', type=str, default="bert-base-uncased") + parser.add_argument('--max_epochs', type=int, default=1) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--lr', type=float, default=2.4e-5) + parser.add_argument('--fused_layernorm', type=bool, default=False) + parser.add_argument('--accumulation_steps', type=int, default=8) + parser.add_argument('--warmup_fraction', type=float, default=0.03) + parser.add_argument('--target_f1', type=float, default=None) + args = parser.parse_args() + train(args) diff --git a/colossalai/shardformer/examples/shardformer_benchmark.sh b/colossalai/shardformer/examples/shardformer_benchmark.sh new file mode 100644 index 000000000000..f42b19a32d35 --- /dev/null +++ b/colossalai/shardformer/examples/shardformer_benchmark.sh @@ -0,0 +1,9 @@ +torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \ + --model "bert" \ + --pretrain "bert-base-uncased" \ + --max_epochs 1 \ + --batch_size 2 \ + --lr 2.4e-5 \ + --fused_layernorm False \ + --accumulation_steps 8 \ + --warmup_fraction 0.03 diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 2116d2e622e2..c2573bc6d4dd 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -17,7 +17,7 @@ class ShardConfig: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. enable_fused_normalization (bool): Whether to use fused layernorm, default is False """ - tensor_parallel_process_group: int = None + tensor_parallel_process_group: ProcessGroup = None enable_fused_normalization: bool = False enable_all_optimization: bool = False From 1b4a90137e6963c3ba3f37bc6e78d895d6b23d10 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 3 Jul 2023 15:29:11 +0800 Subject: [PATCH 46/49] [shardformer] refactored some doc and api (#4137) * [shardformer] refactored some doc and api * polish code --- colossalai/shardformer/README.md | 144 +++++++++----- .../examples/shardformer_benchmark.py | 2 +- colossalai/shardformer/policies/basepolicy.py | 119 ++++-------- colossalai/shardformer/policies/bert.py | 102 ++++------ colossalai/shardformer/policies/bloom.py | 47 ++--- colossalai/shardformer/policies/gpt2.py | 42 ++-- colossalai/shardformer/policies/llama.py | 42 ++-- colossalai/shardformer/policies/opt.py | 51 ++--- colossalai/shardformer/policies/t5.py | 181 ++++++++---------- colossalai/shardformer/policies/vit.py | 17 +- colossalai/shardformer/shard/shard_config.py | 8 +- colossalai/shardformer/shard/sharder.py | 74 +++---- colossalai/shardformer/shard/shardformer.py | 16 +- tests/test_shardformer/test_model/_utils.py | 2 +- tests/test_shardformer/test_with_torch_ddp.py | 2 +- 15 files changed, 357 insertions(+), 492 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 877e28a2db0e..f5d8bb35d91d 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -15,7 +15,12 @@ - [Policy](#policy) - [Model Sharder](#model-sharder) - [User-facing API](#user-facing-api) - - [Shardformer Convergence](#shardformer-convergence) + - [⌨️ Development Notes](#️-development-notes) + - [Add New Policy to Shardformer](#add-new-policy-to-shardformer) + - [Write Your Unit Testing](#write-your-unit-testing) + - [📊 Benchmarking](#-benchmarking) + - [System Performance](#system-performance) + - [Convergence](#convergence) ## 🔗 Introduction @@ -40,12 +45,9 @@ config = BertConfig.from_pretrained('bert-base-uncased') model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config) # create huggingface model as normal -shard_config = ShardConfig(tensor_parallel_size=2, - data_parallel_size=1, - gather_output=True) +shard_config = ShardConfig() shard_former = ShardFormer(shard_config=shard_config) -shard_former.init_distributed() -sharded_model = shard_former.shard_model(model).to('cuda') +sharded_model = shard_former.optimize(model).to('cuda') # do everything like normal ... @@ -67,10 +69,11 @@ class MyPolicy(Policy): # use customized policy to shard model my_policy = MyPolicy() -shard_former.shard_model(model, my_policy) +shard_former.optimize(model, my_policy) + -``` +``` ## 🗺 Roadmap We will follow this roadmap to develop Shardformer: @@ -112,7 +115,6 @@ Please refer to the code for more details.


- This diagram is deprecated, need to update it

@@ -147,15 +149,13 @@ class ParallelModule(torch.nn.Module): ```python @dataclass class ShardConfig: - data_parallel_size: int - tensor_parallel_size: int + tensor_parallel_process_group: ProcessGroup = None + enable_fused_normalization: bool = False ... # Some possible future config fields - pipeline_parallel_size: int # Support pipeline parallelism tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode inference_only: bool # only inject inference-suitable sharding policy - gather_output: bool # gather the model output use_flash_attention: bool # whether to use flash attention to speed up attention ``` @@ -166,42 +166,42 @@ It is merely a description, the actual sharding will be performed by `ModelShard We abstract the policy into four stages: 1. Preprocessing: call `Policy.preprocess` to do some prior work before sharding, for example, resizing the embedding -2. Providing a new class: call `Policy.new_model_class` to get a new class for the model, this class replaces attributes and the forward function -3. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted. -4. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model. +2. Providing `ModulePolicyDescription`: call `Policy.module_policy` to get a bunch of `ModulePolicyDescription` to tell `ModelSharder` how the submodules's attributes, child parameters, and deeper submodules will be substituted. +3. Postprocessing: call `Policy.postprocess` to perform some postprocessing work, for example, binding the embedding and classifier head weights of the BERT model. ``` python @dataclass class ModulePolicyDescription: - """ - Describe how the attributes and parameters will be transformed in a policy + r""" + Describe how the attributes and parameters will be transformed in a policy. Args: attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding - param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive two arguments: module, process_group. One example is - def example_replace_weight(module: torch.nn.Module, process_group): - weight = module.weight - new_weight = shard_rowwise(weight, process_group) - module.weight = torch.nn.Parameter(new_weight) - sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies the module to be replaced and the target module used to replacement + param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function must receive only one arguments: module. + sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription + object which specifies the module to be replaced and the target module used to replacement. + method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement """ - attribute_replacement: Dict[str, Any] - param_replacement: List[Callable] - sub_module_replacement: List[SubModuleReplacementDescription] + attribute_replacement: Dict[str, Any] = None + param_replacement: List[Callable] = None + sub_module_replacement: List[SubModuleReplacementDescription] = None + method_replacement: Dict[str, Callable] = None @dataclass class SubModuleReplacementDescription: - """ + r""" Describe how a submodule will be replaced Args: suffix (str): used to get the submodule object target_module (ParallelModule): specifies the module class used to replace to submodule kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. + ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception """ suffix: str target_module: ParallelModule kwargs: Dict[str, Any] = None + ignore_if_not_exist: bool = False class Policy(ABC): @@ -230,13 +230,6 @@ class Policy(ABC): """ ... - @abstractmethod - def new_model_class(self) -> Union[Type[nn.Module], None]: - """ - replace the class of the model to substitute the forward and attributes - """ - ... - @abstractmethods def postprocess(self) -> nn.Module: """ @@ -253,8 +246,9 @@ class Policy(ABC): ```python class ModelSharder: - def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None) + def __init__(self, model: torch.nn.Module, shard_config: ShardConfig, Policy: ShardPolicy = None): #TODO: input is a cls or a obj + ... def shard(self) -> None: """ @@ -262,15 +256,6 @@ class ModelSharder: """ ... - def replace_model_class(self) -> None: - """ - Replace the model's methods and attributes with our own defined class. - - E.g. we can replace the forward function of the original BertForMaskedLM object - with the forward function we define in BertForMaskedLM_ class. - """ - ... - def replace_module(self) -> None: """ Replace the layer according to the policy. Call Policy.module_policy() to get the module. Call _replace_module recursively. @@ -291,7 +276,7 @@ class ShardFormer: shard_former = ShardFormer(shard_config=shard_config) shard_former.init_distributed() - model = shard_former.shard_model(model, policy=policy) + model = shard_former.optimize(model, policy=policy) dataloader = shard_former.shard_dataset(dataset) """ @@ -326,14 +311,69 @@ class ShardFormer: ... ``` -### Shardformer Convergence +## ⌨️ Development Notes + +### Add New Policy to Shardformer + +This section serves as the guideline for writing new policies and register them into `shardformer`. + +- Step 1. Write your own model policy + +You can create a new file in the `colossalai/shardformer/policies` folder and name the file with the model name. You can implement your policy in this file. You should not import the any model zoo library at the header section of the file because we do not want to import the library when we do not use the policy. Libraries such as `transformers` should be imported only in the function body when needed. + +- Step 2. Register your policy to the autopolicy + +Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file. + +For example, if we register the policy for the BERT model, we just add a key-value in the `_POLICY_LIST` dictionary. The key if the `qualname` of the model object (you can get it by model.__class__.__qualname__). The value is a `PolicyLocation` object, which contains the file name and the class name of the policy. We do not import the policy directly because the policy file may contain libraries (such as `transformers`) which we do not want to import when we do not use the policy. + +```python +_POLICY_LIST = { + # BERT + "transformers.models.bert.modeling_bert.BertModel": + PolicyLocation(file_name="bert", class_name="BertModelPolicy"), +} +``` + +### Write Your Unit Testing + +This section serves as the guideline for testing the `shardformer` module. + +- Step 1. Add your model to the model zoo in the test kits. + +Add your model to the `tests/kit/model_zoo` file. This allows you to define test-related components for this model. You can take `tests/kit/model_zoo/transformers/llama.py` as an example for reference. + +- Step 2. Write your unit testing for the model + +Next, implement your unit test in the `tests/test_shardformer` folder. Please refer to other similar tests for style consistency. + + +- Step 3. Execute your test + +When you run tests locally, you should run tests for both your newly-added test file and the whole `shardformer` module tests. + +```bash +# test for your own test file +pytest tests/test_shardformer/test_model/.py + +# test for the whole shardformer module +pytest tests/test_shardformer +``` + +## 📊 Benchmarking + +### System Performance + +To be added. + +### Convergence To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/shardformer_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results. -| accuracy | f1 | loss | GPU number | model shard | -| :-----: | :----: | :----: | :----: | :----: | -| 0.82594 | 0.87441 | 0.09913 | 4 | True | -| 0.81884 | 0.87299 | 0.10120 | 2 | True | -| 0.81855 | 0.87124 | 0.10357 | 1 | False | +| accuracy | f1 | loss | GPU number | model shard | +| :------: | :-----: | :-----: | :--------: | :---------: | +| 0.82594 | 0.87441 | 0.09913 | 4 | True | +| 0.81884 | 0.87299 | 0.10120 | 2 | True | +| 0.81855 | 0.87124 | 0.10357 | 1 | False | Overall, the results demonstrate that using shardformers during model training does not affect the convergence. diff --git a/colossalai/shardformer/examples/shardformer_benchmark.py b/colossalai/shardformer/examples/shardformer_benchmark.py index bb3560ee9e21..de82305b2547 100644 --- a/colossalai/shardformer/examples/shardformer_benchmark.py +++ b/colossalai/shardformer/examples/shardformer_benchmark.py @@ -51,7 +51,7 @@ def train(args): if dist.get_world_size() > 1: shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm) shard_former = ShardFormer(shard_config=shard_config) - model = shard_former.shard_model(model) + model = shard_former.optimize(model) optim = Adam(model.parameters(), lr=args.lr) num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 2b972606948c..9ea3d95de5b2 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -22,9 +22,11 @@ class SubModuleReplacementDescription: r""" Describe how a submodule will be replaced - suffix (str): used to get the submodule object - target_module (ParallelModule): specifies the module class used to replace to submodule - kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. + Args: + suffix (str): used to get the submodule object + target_module (ParallelModule): specifies the module class used to replace to submodule + kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. + ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception """ suffix: str target_module: ParallelModule @@ -35,47 +37,37 @@ class SubModuleReplacementDescription: @dataclass class ModulePolicyDescription: r""" - Describe how the attributes and parameters will be transformed in a policy - - attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding - param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function - must receive two arguments: module, process_group. One example is - - ```python - def example_replace_weight(module: torch.nn.Module, process_group): - weight = module.weight - new_weight = shard_rowwise(weight, process_group) - module.weight = torch.nn.Parameter(new_weight) - ``` - - sub_module_replacement: each element in the list is a ParamReplacementDescription object which specifies - the module to be replaced and the target module used to replacement + Describe how the attributes and parameters will be transformed in a policy. + + Args: + attribute_replacement (Dict[str, Any]): key is the attribute name, value is the attribute value after sharding + param_replacement (List[Callable]): a list of functions to perform in-place param replacement. The function + must receive only one arguments: module. One example is + + ```python + def example_replace_weight(module: torch.nn.Module): + weight = module.weight + new_weight = shard_rowwise(weight, process_group) + module.weight = torch.nn.Parameter(new_weight) + ``` + sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription + object which specifies the module to be replaced and the target module used to replacement. + method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement """ - attribute_replacement: Dict[str, Any] - param_replacement: List[Callable] - sub_module_replacement: List[SubModuleReplacementDescription] - method_replacement: List[Callable] = None + attribute_replacement: Dict[str, Any] = None + param_replacement: List[Callable] = None + sub_module_replacement: List[SubModuleReplacementDescription] = None + method_replacement: Dict[str, Callable] = None class Policy(ABC): r""" - The base class for all the policies - - For each different model, it should have a different policy class, like BertPolicy for Bert Model - or OPTPolicy for OPT model. - - AutoPolicy: - Shardformer already defined some policies for huggingface model, just set ``custom_policy`` = None - to use the auto policy. In shardformer autopolicy, we define a base policy for one type model, - like BertPolicy, and for each different Bert modle in huggingface like, BertForMaskedLM, - BertForSequenceClassification, etc., for each different Bert model we difine different policy class - and overwrite the method like ``inject_policy`` to modify the forward and backward process. - - CustomPolicy: - If you want to define your own policy, you can set ``custom_policy`` = CustomPolicy, and overwrite - all the methods in ``Policy`` class. You can refer to any policy we defined like the ``BertPolicy`` - class for the example. + The base class for all the policies. For each different model, it should have a different policy class, + like BertPolicy for Bert Model or OPTPolicy for OPT model. + Shardformer has provided many built-in sharding policies for the mainstream models. You can use the + built-in policies by setting `policy = None`, which is already the default arguemnt for `Shardformer.optimize`. + If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify. """ def __init__(self) -> None: @@ -106,63 +98,24 @@ def set_shard_config(self, shard_config: ShardConfig) -> None: def config_sanity_check(self): """ Check if the shard config is valid for the model. Raise an exception if the config is invalid. + This method is made abstractmethod with no default implementation because we want to the policy writer + to take note of the feature supported by his/her model and policy. """ pass @abstractmethod def preprocess(self) -> nn.Module: r""" - Perform some preprocessing of the model, like reshaping the embedding layer + Perform some preprocessing of the model, like reshaping the embedding layer. """ pass @abstractmethod def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: r""" - Return the dict for the modify policy, the key is the original layer class and the value is the - argument for the modify layer - - Return: - Dict for the modify policy, - :: - { - origin layer class1 (nn.Module): ModulePolicyDescription( - attribute_replacement = { - "attribute1": value1, - "attribute2": value2, - ... - }, - param_replacement = [ - function1, - function2, - ... - ], - sub_module_replacement = [ - `SubModuleReplacementDescription` description1, - `SubModuleReplacementDescription` description2, - ... - ] - ), - origin layer class2 (nn.Module): ModulePolicyDescription( - ... - ), - ... - } - """ - pass - - @abstractmethod - def new_model_class(self) -> Union[Type[nn.Module], None]: - r""" - Return the new model class for the new model, None means no need to modify the model class - - Return: - New model class - - E.g. - ``` - return BertModel_ - ``` + This method returns the module policy, which is a dictionary. The key is the module name or the module object, + and the value is the ModulePolicyDescription object. The ModulePolicyDescription object describes how the module + will be transformed. """ pass diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 7cf6caf7ca49..5ab8fb825244 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -48,7 +48,6 @@ def module_policy(self): "crossattention.self.num_attention_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, - param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( suffix="attention.self.query", @@ -88,18 +87,16 @@ def module_policy(self): ) ]), BertEmbeddings: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, - ) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ) + ]) } # optimization configuration @@ -121,10 +118,6 @@ def module_policy(self): ),) return base_policy - def new_model_class(self): - # do nothing - return None - def postprocess(self): return self.model @@ -148,13 +141,10 @@ def module_policy(self): module_policy = super().module_policy() addon_module = { BertLMPredictionHead: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="decoder", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}), - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), + ]) } # optimization configuration @@ -191,13 +181,10 @@ def module_policy(self): module_policy = super().module_policy() addon_module = { BertLMPredictionHead: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="decoder", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}), - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), + ]) } if self.shard_config.enable_fused_normalization: addon_module[BertLMPredictionHead].sub_module_replacement.append( @@ -230,13 +217,10 @@ def module_policy(self): module_policy = super().module_policy() addon_module = { BertLMPredictionHead: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="decoder", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}), - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), + ]) } # optimization configuration @@ -272,14 +256,12 @@ def module_policy(self): module_policy = super().module_policy() addon_module = { BertForSequenceClassification: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) } module_policy.update(addon_module) return module_policy @@ -297,14 +279,12 @@ def module_policy(self): module_policy = super().module_policy() addon_module = { BertForTokenClassification: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) } module_policy.update(addon_module) return module_policy @@ -329,14 +309,12 @@ def module_policy(self): module_policy = super().module_policy() addon_module = { BertForMultipleChoice: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) } module_policy.update(addon_module) return module_policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index c59cfbb405fc..00ab9159b0dc 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -98,7 +98,6 @@ def module_policy(self): "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, }, - param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( suffix="self_attention.query_key_value", @@ -125,7 +124,6 @@ def module_policy(self): ModulePolicyDescription(attribute_replacement={ "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, }, - param_replacement=[], method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor}, sub_module_replacement=[ SubModuleReplacementDescription( @@ -160,10 +158,6 @@ def module_policy(self): return base_policy - def new_model_class(self): - # do nothing - return None - def postprocess(self): return self.model @@ -180,13 +174,10 @@ def module_policy(self): # add a new item for casual lm new_item = { BloomForCausalLM: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs=dict(gather_output=True)) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) + ]) } policy.update(new_item) return policy @@ -213,13 +204,10 @@ def module_policy(self): # add a new item for casual lm new_item = { BloomForSequenceClassification: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="score", - target_module=col_nn.Linear1D_Col, - kwargs=dict(gather_output=True)) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) + ]) } policy.update(new_item) return policy @@ -233,17 +221,14 @@ def module_policy(self): # add a new item for casual lm new_item = { BloomForTokenClassification: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="classifier", - target_module=col_nn.Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ]) } policy.update(new_item) return policy diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index c6108f5c0e85..ad0b1144a8a5 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -31,23 +31,20 @@ def preprocess(self): def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model - return { + base_policy = { GPT2Model: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), - ]), + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ]), GPT2Block: ModulePolicyDescription(attribute_replacement={ "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, - param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( suffix="attn.c_attn", @@ -110,9 +107,6 @@ def module_policy(self): return base_policy - def new_model_class(self): - return self.model - def postprocess(self): return self.model @@ -136,13 +130,10 @@ def module_policy(self): module_policy = super().module_policy() addon_module = { GPT2LMHeadModel: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) } module_policy.update(addon_module) return module_policy @@ -169,13 +160,10 @@ def module_policy(self): module_policy = super().module_policy() addon_module = { GPT2DoubleHeadsModel: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="lm_head", - target_module=col_nn.Linear1D_Col, - kwargs={"gather_output": True}) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) } module_policy.update(addon_module) return module_policy diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 2fd2bc22303b..8f397693745c 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -28,7 +28,7 @@ def preprocess(self): def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel - return { + base_policy = { LlamaDecoderLayer: ModulePolicyDescription( attribute_replacement={ @@ -37,7 +37,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, }, - param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( suffix="self_attn.q_proj", @@ -70,14 +69,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: ], ), LlamaModel: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]) } # optimization configuration @@ -101,9 +98,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: return base_policy - def new_model_class(self): - return None - def postprocess(self): return self.model @@ -117,13 +111,10 @@ def module_policy(self): # add a new item for casual lm new_item = { LlamaForCausalLM: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) } policy.update(new_item) return policy @@ -139,13 +130,10 @@ def module_policy(self): # add a new item for sequence classification new_item = { LlamaForSequenceClassification: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="score", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) } policy.update(new_item) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index dfbaaf5785ba..428ee2c9776c 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -31,33 +31,28 @@ def module_policy(self): base_policy = { OPTDecoder: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]), + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]), OPTDecoderLayer: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]), + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]), OPTAttention: ModulePolicyDescription(attribute_replacement={ "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size }, - param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( suffix="q_proj", @@ -95,9 +90,6 @@ def module_policy(self): return base_policy - def new_model_class(self): - return None - def postprocess(self): return self.model @@ -116,13 +108,10 @@ def module_policy(self): policy = super().module_policy() new_item = { OPTForCausalLM: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription(suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) } policy.update(new_item) diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 8853687e7621..37fccaabc457 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -44,36 +44,30 @@ def module_policy(self): base_policy = { T5Stack: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=Embedding1D, - ) - ]), + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=Embedding1D, + ) + ]), T5LayerSelfAttention: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - ]), + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ]), T5LayerCrossAttention: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]), + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]), T5Attention: ModulePolicyDescription(attribute_replacement={ "d_model": @@ -83,7 +77,6 @@ def module_policy(self): "inner_dim": self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size }, - param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( suffix="q", @@ -107,51 +100,44 @@ def module_policy(self): ignore_if_not_exist=True) ]), T5LayerFF: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - ]), + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ]), T5DenseGatedActDense: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wi_0", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="wi_1", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription(suffix="wo", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)), - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]), + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]), T5DenseActDense: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wi", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="wo", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]) + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]) } # optimization configuration @@ -167,9 +153,6 @@ def module_policy(self): return base_policy - def new_model_class(self): - return None - def postprocess(self): binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] @@ -185,14 +168,12 @@ def module_policy(self): from transformers import T5Model base_policy = super().module_policy() - base_policy[T5Model] = ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ) - ]) + base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ) + ]) return base_policy @@ -202,18 +183,14 @@ def module_policy(self): from transformers import T5ForConditionalGeneration policy = super().module_policy() - policy[T5ForConditionalGeneration] = ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)) - ]) + policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) return policy def postprocess(self): @@ -235,14 +212,12 @@ def module_policy(self): from transformers import T5EncoderModel base_policy = super().module_policy() - base_policy[T5EncoderModel] = ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ) - ]) + base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ) + ]) return base_policy def postprocess(self): diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 6a404c2faf0f..eaebe2eee0ba 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -28,16 +28,14 @@ def preprocess(self): def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer - return { + base_policy = { ViTEmbeddings: - ModulePolicyDescription(attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForReplicatedInput, - ) - ]), + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForReplicatedInput, + ) + ]), ViTLayer: ModulePolicyDescription(attribute_replacement={ "attention.attention.num_attention_heads": @@ -45,7 +43,6 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: "attention.attention.all_head_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, }, - param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( suffix="attention.attention.query", diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index c2573bc6d4dd..0a5aa4cc4bdc 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -3,8 +3,6 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.cluster.dist_coordinator import DistCoordinator - __all__ = ['ShardConfig'] @@ -15,7 +13,9 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. - enable_fused_normalization (bool): Whether to use fused layernorm, default is False + enable_fused_normalization (bool): Whether to use fused layernorm, default is False. + enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. + enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None enable_fused_normalization: bool = False @@ -45,4 +45,4 @@ def _turn_on_all_optimization(self): Turn on all optimization. """ # you can add all the optimization flag here - self.fused_layernorm = True + self.enable_fused_normalization = True diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 81c032b95f03..2867a0a4fd77 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -1,9 +1,7 @@ -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Union import torch.nn as nn -from colossalai.cluster.process_group_manager import ProcessGroupManager - from .._utils import getattr_, setattr_ from ..policies.autopolicy import get_autopolicy from ..policies.basepolicy import Policy, SubModuleReplacementDescription @@ -34,7 +32,6 @@ def shard(self) -> None: self.policy.set_model(self.model) self.policy.set_shard_config(self.shard_config) self._preprocess() - self._replace_model_class() self._replace_module() self._postprocess() @@ -44,27 +41,6 @@ def _preprocess(self) -> None: def _postprocess(self) -> None: self.model = self.policy.postprocess() - def _replace_model_class(self,) -> None: - r""" - Replace the model to policy defined model - Mainly modify the forward and backward to fit distributed model - - e.g. - :: - BertForMaskedLM.forward -> BertForMaskedLM_.forward - """ - new_model_class = self.policy.new_model_class() - if new_model_class is None: - return - - for key in new_model_class.__dict__.keys(): - if hasattr(self.model.__class__, key): - setattr( - self.model.__class__, - key, - getattr(new_model_class, key), - ) - def _replace_module(self,) -> None: r""" Replace the module according to the policy, and replace the module one by one @@ -73,19 +49,18 @@ def _replace_module(self,) -> None: model (:class:`torch.nn.Module`): The model to shard """ module_descriptions = self.policy.module_policy() - for module_description in module_descriptions.items(): - origin_layer_cls = module_description[0] - attr_replacement = module_description[1].attribute_replacement - param_replacement = module_description[1].param_replacement - sub_module_replacement = module_description[1].sub_module_replacement - method_replacement = module_description[1].method_replacement - self._recursive_replace_layer(self.model, origin_layer_cls, attr_replacement, param_replacement, + for layer_cls, module_description in module_descriptions.items(): + attr_replacement = module_description.attribute_replacement + param_replacement = module_description.param_replacement + sub_module_replacement = module_description.sub_module_replacement + method_replacement = module_description.method_replacement + self._recursive_replace_layer(self.model, layer_cls, attr_replacement, param_replacement, method_replacement, sub_module_replacement) def _recursive_replace_layer( self, module: nn.Module, - origin_cls: nn.Module, + origin_cls: Union[str, nn.Module], attr_replacement: Dict[str, Any], param_replacement: List[Callable], method_replacement: Dict[str, Callable], @@ -95,17 +70,25 @@ def _recursive_replace_layer( Reverse the replace layer operation Args: - layer (:class:`torch.nn.Module`): The object of layer to shard - origin_cls (:class:`transformers.model`): The origin layer class + layer (torch.nn.Module): The object of layer to shard + origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name. attr_replacement (Dict): The attribute dict to modify param_replacement (List[Callable]): The function list to get parameter shard information in polic sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy """ - if module.__class__ == origin_cls: - self._replace_attr(module, attr_replacement) - self._replace_param(module, param_replacement) - self._replace_method(module, method_replacement) - self._replace_sub_module(module, sub_module_replacement) + if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \ + (module.__class__ == origin_cls): + if attr_replacement is not None: + self._replace_attr(module, attr_replacement) + + if param_replacement is not None: + self._replace_param(module, param_replacement) + + if method_replacement is not None: + self._replace_method(module, method_replacement) + + if sub_module_replacement is not None: + self._replace_sub_module(module, sub_module_replacement) for name, child in module.named_children(): self._recursive_replace_layer(child, origin_cls, attr_replacement, param_replacement, method_replacement, @@ -138,13 +121,10 @@ def _replace_param( layer (:class:`torch.nn.Module`): The object of layer to shard param_replacement (List[Callable]): The function list to get parameter shard information in policy """ - # TODO: support parameter shard - pass + for param_func in param_replacement: + param_func(module) def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]): - if method_replacement is None: - return - for method_name, new_method in method_replacement.items(): # bind the new method to the module setattr(module, method_name, new_method.__get__(module, module.__class__)) @@ -158,8 +138,8 @@ def _replace_sub_module( Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict Args: - org_layer (:class:`torch.nn.Module`): The origin layer object to shard - param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class + org_layer (torch.nn.Module): The origin layer object to shard + sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list """ for description in sub_module_replacement: diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py index 7c4220c3a9fb..3fce12463414 100644 --- a/colossalai/shardformer/shard/shardformer.py +++ b/colossalai/shardformer/shard/shardformer.py @@ -22,27 +22,19 @@ class ShardFormer: colossalai.launch_from_torch(config={}) org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') - shard_config = ShardConfig( - tensor_parallel_size=2, - tensor_parallel_mode='1d', - ) + shard_config = ShardConfig() shard_former = ShardFormer(shard_config=shard_config) - model = shard_former.shard_model(org_model) + model = shard_former.optimize(org_model) ``` """ def __init__(self, shard_config: ShardConfig): - """ - Do two things: - 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp - 2. serve as a store for - """ self.coordinator = DistCoordinator() self.shard_config = shard_config - def shard_model(self, model: nn.Module, policy: Policy = None): + def optimize(self, model: nn.Module, policy: Policy = None): r""" - The function is used to shard the PyTorch model. + This method will optimize the model based on the given policy. Args: model (`torch.nn.Model`): the origin huggingface model diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index a6355bf1c75e..aa1424af3289 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -11,7 +11,7 @@ def build_model(model_fn): shard_config = ShardConfig(enable_fused_normalization=True) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) - sharded_model = shard_former.shard_model(model_copy).cuda() + sharded_model = shard_former.optimize(model_copy).cuda() return org_model, sharded_model diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py index 61b672650965..9f8a5db6c94f 100644 --- a/tests/test_shardformer/test_with_torch_ddp.py +++ b/tests/test_shardformer/test_with_torch_ddp.py @@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): # create and shard model model = model_fn().cuda() - sharded_model = shardformer.shard_model(model) + sharded_model = shardformer.optimize(model) # add ddp sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group) From f8dcf9d831d14931334309e95a4148b0cafd88e9 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 4 Jul 2023 09:57:03 +0800 Subject: [PATCH 47/49] [shardformer] made tensor parallelism configurable (#4144) * [shardformer] made tensor parallelism configurable * polish code --- colossalai/shardformer/policies/basepolicy.py | 25 ++ colossalai/shardformer/policies/bert.py | 299 ++++++++---------- colossalai/shardformer/policies/bloom.py | 166 +++++----- colossalai/shardformer/policies/gpt2.py | 163 +++++----- colossalai/shardformer/policies/llama.py | 150 ++++----- colossalai/shardformer/policies/opt.py | 114 ++++--- colossalai/shardformer/policies/t5.py | 266 ++++++++-------- colossalai/shardformer/shard/shard_config.py | 10 +- tests/test_shardformer/test_model/_utils.py | 5 +- .../test_model/test_shard_bert.py | 45 ++- .../test_model/test_shard_bloom.py | 45 ++- .../test_model/test_shard_gpt2.py | 46 ++- .../test_model/test_shard_llama.py | 48 ++- .../test_model/test_shard_opt.py | 41 ++- .../test_model/test_shard_t5.py | 59 +++- 15 files changed, 814 insertions(+), 668 deletions(-) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 9ea3d95de5b2..85e6d509c81b 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -126,3 +126,28 @@ def postprocess(self) -> nn.Module: the classifier layer """ pass + + def append_or_create_submodule_replacement( + self, description: Union[SubModuleReplacementDescription, + List[SubModuleReplacementDescription]], policy: Dict[Union[str, nn.Module], + ModulePolicyDescription], + target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: + r""" + Append or create a new submodule replacement description to the policy for the given key. + + Args: + submodule_replace_desc (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description to be appended + policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated + target_key (Union[str, nn.Module]): the key of the policy to be updated + """ + # convert to list + if isinstance(description, SubModuleReplacementDescription): + description = [description] + + # append or create a new description + if target_key in policy: + policy[target_key].sub_module_replacement.extend(description) + else: + policy[target_key] = ModulePolicyDescription(sub_module_replacement=description) + + return policy diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 5ab8fb825244..9c2736cc64d3 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -33,89 +33,114 @@ def preprocess(self): def module_policy(self): from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer - base_policy = { - BertLayer: - ModulePolicyDescription( - attribute_replacement={ - # 1. shard hidden size - "attention.self.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "crossattention.self.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - # 2. shard number of heads - "attention.self.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "crossattention.self.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.self.query", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.key", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.self.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]), - BertEmbeddings: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, - ) - ]) - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[BertLayer] = ModulePolicyDescription(attribute_replacement={ + "attention.self.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "crossattention.self.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attention.self.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "crossattention.self.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.self.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.self.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + + policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ) + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[BertLayer].sub_module_replacement.append( + # Handle bert layer + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( suffix="attention.output.LayerNorm", target_module=col_nn.FusedLayerNorm, - )) - base_policy[BertLayer].sub_module_replacement.append( + ), SubModuleReplacementDescription( suffix="output.LayerNorm", target_module=col_nn.FusedLayerNorm, - )) - base_policy[BertEmbeddings].sub_module_replacement.append( - SubModuleReplacementDescription( + ) + ], + policy=policy, + target_key=BertLayer) + + # handle embedding layer + self.append_or_create_submodule_replacement( + description=[SubModuleReplacementDescription( suffix="LayerNorm", target_module=col_nn.FusedLayerNorm, - ),) + )], + policy=policy, + target_key=BertEmbeddings) + return policy + + def add_lm_head_policy(self, base_policy): + from transformers.models.bert.modeling_bert import BertLMPredictionHead + + # optimize for tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), + policy=base_policy, + target_key=BertLMPredictionHead) + + # optimize with fused normalization + if self.shard_config.enable_fused_normalization: + # Handle bert lm prediction head + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="transform.LayerNorm", + target_module=col_nn.FusedLayerNorm, + ), + policy=base_policy, + target_key=BertLMPredictionHead) return base_policy def postprocess(self): @@ -136,35 +161,14 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.bert.modeling_bert import BertLMPredictionHead - module_policy = super().module_policy() - addon_module = { - BertLMPredictionHead: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - ]) - } - - # optimization configuration - if self.shard_config.enable_fused_normalization: - addon_module[BertLMPredictionHead].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="transform.LayerNorm", - target_module=col_nn.FusedLayerNorm, - )) - - # append extra policy - module_policy.update(addon_module) + module_policy = self.add_lm_head_policy(module_policy) return module_policy def postprocess(self): binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model @@ -176,31 +180,14 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.bert.modeling_bert import BertLMPredictionHead - module_policy = super().module_policy() - addon_module = { - BertLMPredictionHead: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - ]) - } - if self.shard_config.enable_fused_normalization: - addon_module[BertLMPredictionHead].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="transform.LayerNorm", - target_module=col_nn.FusedLayerNorm, - )) - module_policy.update(addon_module) + module_policy = self.add_lm_head_policy(module_policy) return module_policy def postprocess(self): binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model @@ -212,34 +199,14 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.bert.modeling_bert import BertLMPredictionHead - module_policy = super().module_policy() - addon_module = { - BertLMPredictionHead: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}), - ]) - } - - # optimization configuration - if self.shard_config.enable_fused_normalization: - addon_module[BertLMPredictionHead].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="transform.LayerNorm", - target_module=col_nn.FusedLayerNorm, - )) - - module_policy.update(addon_module) + module_policy = self.add_lm_head_policy(module_policy) return module_policy def postprocess(self): binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model @@ -254,16 +221,18 @@ def module_policy(self): from transformers.models.bert.modeling_bert import BertForSequenceClassification module_policy = super().module_policy() - addon_module = { - BertForSequenceClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) - } - module_policy.update(addon_module) + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForSequenceClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + } + module_policy.update(addon_module) return module_policy @@ -277,16 +246,18 @@ def module_policy(self): from transformers.models.bert.modeling_bert import BertForTokenClassification module_policy = super().module_policy() - addon_module = { - BertForTokenClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) - } - module_policy.update(addon_module) + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForTokenClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + } + module_policy.update(addon_module) return module_policy @@ -307,14 +278,16 @@ def module_policy(self): from transformers.models.bert.modeling_bert import BertForMultipleChoice module_policy = super().module_policy() - addon_module = { - BertForMultipleChoice: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForParallelInput, - ) - ]) - } - module_policy.update(addon_module) + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + BertForMultipleChoice: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForParallelInput, + ) + ]) + } + module_policy.update(addon_module) return module_policy diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 00ab9159b0dc..030774a919d7 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -85,57 +85,53 @@ def preprocess(self): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel - base_policy = { - BloomBlock: - ModulePolicyDescription( - attribute_replacement={ - # 1. shard hidden size - "self_attention.hidden_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.split_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - # 2. shard number of heads - "self_attention.num_heads": - self.model.config.n_head // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=col_nn.Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=col_nn.Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=col_nn.Linear1D_Row, - ), - ]), - BloomModel: - ModulePolicyDescription(attribute_replacement={ + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=col_nn.Linear1D_Row, + ), + ]) + + policy[BloomModel] = ModulePolicyDescription( + attribute_replacement={ "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, }, - method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor}, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="word_embeddings", - target_module=col_nn.VocabParallelEmbedding1D, - ) - ]) - } + method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor}, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="word_embeddings", + target_module=col_nn.VocabParallelEmbedding1D, + ) + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[BloomModel].sub_module_replacement.extend([ + # handle bloom model + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( suffix="ln_f", target_module=col_nn.FusedLayerNorm, @@ -144,8 +140,12 @@ def module_policy(self): suffix="word_embeddings_layernorm", target_module=col_nn.FusedLayerNorm, ) - ]) - base_policy[BloomBlock].sub_module_replacement.extend([ + ], + policy=policy, + target_key=BloomModel) + + # handle bloom block + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( suffix="input_layernorm", target_module=col_nn.FusedLayerNorm, @@ -154,9 +154,11 @@ def module_policy(self): suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm, ) - ]) + ], + policy=policy, + target_key=BloomBlock) - return base_policy + return policy def postprocess(self): return self.model @@ -171,19 +173,19 @@ class BloomForCausalLMPolicy(BloomPolicy): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForCausalLM policy = super().module_policy() - # add a new item for casual lm - new_item = { - BloomForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } - policy.update(new_item) + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=BloomForCausalLM) + return policy def postprocess(self): binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"} + for k, v in binding_map.items(): param = getattr_(self.model, k) @@ -191,7 +193,6 @@ def postprocess(self): param = nn.Parameter(param) # tie weights - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model @@ -201,15 +202,14 @@ class BloomForSequenceClassificationPolicy(BloomPolicy): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification policy = super().module_policy() - # add a new item for casual lm - new_item = { - BloomForSequenceClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } - policy.update(new_item) + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=BloomForSequenceClassification) + return policy @@ -218,19 +218,21 @@ class BloomForTokenClassificationPolicy(BloomPolicy): def module_policy(self): from transformers.models.bloom.modeling_bloom import BloomForTokenClassification policy = super().module_policy() - # add a new item for casual lm - new_item = { - BloomForTokenClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)), - SubModuleReplacementDescription( - suffix="dropout", - target_module=col_nn.DropoutForReplicatedInput, - ), - ]) - } - policy.update(new_item) + + # handle tensor parallelism + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription(suffix="classifier", + target_module=col_nn.Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ], + policy=policy, + target_key=BloomForTokenClassification) + return policy diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index ad0b1144a8a5..549cdbf87a80 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -31,67 +31,67 @@ def preprocess(self): def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model - base_policy = { - GPT2Model: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wte", - target_module=col_nn.VocabParallelEmbedding1D, - ), - ]), - GPT2Block: - ModulePolicyDescription(attribute_replacement={ - "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attn.c_attn", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 3, - }, - ), - SubModuleReplacementDescription( - suffix="attn.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.c_fc", - target_module=col_nn.GPT2FusedLinearConv1D_Col, - kwargs={ - "n_fused": 1, - }, - ), - SubModuleReplacementDescription( - suffix="mlp.c_proj", - target_module=col_nn.GPT2FusedLinearConv1D_Row, - ), - SubModuleReplacementDescription( - suffix="attn.attn_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attn.resid_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dropout", - target_module=col_nn.DropoutForParallelInput, - ), - ]) - } + policy = {} - # optimization configuration - if self.shard_config.enable_fused_normalization: - base_policy[GPT2Model].sub_module_replacement.append( + if self.shard_config.enable_tensor_parallelism: + policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[ SubModuleReplacementDescription( - suffix="ln_f", - target_module=col_nn.FusedLayerNorm, - )) + suffix="wte", + target_module=col_nn.VocabParallelEmbedding1D, + ), + ]) + policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={ + "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attn.c_attn", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 3, + }, + ), + SubModuleReplacementDescription( + suffix="attn.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="mlp.c_fc", + target_module=col_nn.GPT2FusedLinearConv1D_Col, + kwargs={ + "n_fused": 1, + }, + ), + SubModuleReplacementDescription( + suffix="mlp.c_proj", + target_module=col_nn.GPT2FusedLinearConv1D_Row, + ), + SubModuleReplacementDescription( + suffix="attn.attn_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attn.resid_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + ]) - base_policy[GPT2Block].sub_module_replacement.extend([ + # optimization configuration + if self.shard_config.enable_fused_normalization: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="ln_f", + target_module=col_nn.FusedLayerNorm, + ), + policy=policy, + target_key=GPT2Model) + + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( suffix="ln_1", target_module=col_nn.FusedLayerNorm, @@ -103,9 +103,10 @@ def module_policy(self): SubModuleReplacementDescription(suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True) - ]) - - return base_policy + ], + policy=policy, + target_key=GPT2Block) + return policy def postprocess(self): return self.model @@ -128,22 +129,22 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel module_policy = super().module_policy() - addon_module = { - GPT2LMHeadModel: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) - } - module_policy.update(addon_module) + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2LMHeadModel: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) return module_policy def postprocess(self): binding_map = {"transformer.wte.weight": "lm_head.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model @@ -158,22 +159,22 @@ def module_policy(self): from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel module_policy = super().module_policy() - addon_module = { - GPT2DoubleHeadsModel: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) - } - module_policy.update(addon_module) + + if self.shard_config.enable_tensor_parallelism: + addon_module = { + GPT2DoubleHeadsModel: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) + ]) + } + module_policy.update(addon_module) return module_policy def postprocess(self): binding_map = {"transformer.wte.weight": "lm_head.weight"} for k, v in binding_map.items(): param = getattr_(self.model, k) - param = nn.Parameter(param) - setattr_(self.model, k, param) setattr_(self.model, v, param) return self.model diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index 8f397693745c..157785bdcf13 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -28,58 +28,58 @@ def preprocess(self): def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel - base_policy = { - LlamaDecoderLayer: - ModulePolicyDescription( - attribute_replacement={ - "self_attn.hidden_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attn.num_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attn.q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="self_attn.o_proj", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="mlp.gate_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.up_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="mlp.down_proj", - target_module=Linear1D_Row, - ) - ], - ), - LlamaModel: - ModulePolicyDescription(sub_module_replacement=[ + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement={ + "self_attn.hidden_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=Linear1D_Row, + ), SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, + suffix="mlp.gate_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=Linear1D_Row, ) - ]) - } + ], + ) + + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ), + policy=policy, + target_key=LlamaModel) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[LlamaDecoderLayer].sub_module_replacement.extend([ + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription( suffix="input_layernorm", target_module=FusedRMSNorm, @@ -88,15 +88,18 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: suffix="post_attention_layernorm", target_module=FusedRMSNorm, ) - ]) + ], + policy=policy, + target_key=LlamaDecoderLayer) - base_policy[LlamaModel].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="norm", - target_module=FusedRMSNorm, - )) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=LlamaModel) - return base_policy + return policy def postprocess(self): return self.model @@ -108,15 +111,17 @@ def module_policy(self): from transformers import LlamaForCausalLM policy = super().module_policy() - # add a new item for casual lm - new_item = { - LlamaForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } - policy.update(new_item) + + if self.shard_config.enable_tensor_parallelism: + # add a new item for casual lm + new_item = { + LlamaForCausalLM: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) return policy @@ -127,13 +132,14 @@ def module_policy(self): policy = super().module_policy() - # add a new item for sequence classification - new_item = { - LlamaForSequenceClassification: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } - policy.update(new_item) + if self.shard_config.enable_tensor_parallelism: + # add a new item for sequence classification + new_item = { + LlamaForSequenceClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) return policy diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 428ee2c9776c..b87db53f45f1 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -29,66 +29,67 @@ def preprocess(self): def module_policy(self): from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer - base_policy = { - OPTDecoder: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=VocabParallelEmbedding1D, - ) - ]), - OPTDecoderLayer: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]), - OPTAttention: - ModulePolicyDescription(attribute_replacement={ - "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]), - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=VocabParallelEmbedding1D, + ) + ]) + policy[OPTDecoderLayer] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]) + + policy[OPTAttention] = ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[OPTDecoder].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True)) - base_policy[OPTDecoderLayer].sub_module_replacement.extend([ + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), + policy=policy, + target_key=OPTDecoder) + self.append_or_create_submodule_replacement(description=[ SubModuleReplacementDescription(suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True), SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True) - ]) + ], + policy=policy, + target_key=OPTDecoderLayer) - return base_policy + return policy def postprocess(self): return self.model @@ -106,15 +107,12 @@ def module_policy(self): from transformers.models.opt.modeling_opt import OPTForCausalLM policy = super().module_policy() - new_item = { - OPTForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) - } - policy.update(new_item) + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + policy=policy, + target_key=OPTForCausalLM) return policy def postprocess(self): diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 37fccaabc457..cde59ab77042 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -42,116 +42,126 @@ def module_policy(self): T5Stack, ) - base_policy = { - T5Stack: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=Embedding1D, - ) - ]), - T5LayerSelfAttention: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - ]), - T5LayerCrossAttention: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]), - T5Attention: - ModulePolicyDescription(attribute_replacement={ - "d_model": - self.model.config.d_model // self.shard_config.tensor_parallel_size, - "n_heads": - self.model.config.num_heads // self.shard_config.tensor_parallel_size, - "inner_dim": - self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="o", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription(suffix="relative_attention_bias", - target_module=Embedding1D, - kwargs=dict(gather_output=False), - ignore_if_not_exist=True) - ]), - T5LayerFF: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ), - ]), - T5DenseGatedActDense: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wi_0", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="wi_1", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]), - T5DenseActDense: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="wi", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="wo", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForParallelInput, - ) - ]) - } + policy = {} + + if self.shard_config.enable_tensor_parallelism: + policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=Embedding1D, + ) + ]) + policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ]) + policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]) + policy[T5Attention] = ModulePolicyDescription(attribute_replacement={ + "d_model": + self.model.config.d_model // self.shard_config.tensor_parallel_size, + "n_heads": + self.model.config.num_heads // self.shard_config.tensor_parallel_size, + "inner_dim": + self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="o", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="relative_attention_bias", + target_module=Embedding1D, + kwargs=dict(gather_output=False), + ignore_if_not_exist=True) + ]) + policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ), + ]) + policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]) + policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=DropoutForParallelInput, + ) + ]) # optimization configuration if self.shard_config.enable_fused_normalization: - base_policy[T5LayerFF].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) - base_policy[T5LayerSelfAttention].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) - base_policy[T5LayerCrossAttention].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm)) - base_policy[T5Stack].sub_module_replacement.append( - SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm)) - - return base_policy + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", + target_module=FusedRMSNorm, + ), + policy=policy, + target_key=T5LayerFF) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerSelfAttention) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5LayerCrossAttention) + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( + suffix="final_layer_norm", target_module=FusedRMSNorm), + policy=policy, + target_key=T5Stack) + return policy def postprocess(self): binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]] @@ -166,14 +176,15 @@ class T5ModelPolicy(T5BasePolicy): def module_policy(self): from transformers import T5Model - base_policy = super().module_policy() - base_policy[T5Model] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="shared", target_module=VocabParallelEmbedding1D, - ) - ]) + ), + policy=base_policy, + target_key=T5Model) return base_policy @@ -183,14 +194,19 @@ def module_policy(self): from transformers import T5ForConditionalGeneration policy = super().module_policy() - policy[T5ForConditionalGeneration] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="shared", - target_module=VocabParallelEmbedding1D, - ), - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=[ + SubModuleReplacementDescription( + suffix="shared", + target_module=VocabParallelEmbedding1D, + ), + SubModuleReplacementDescription(suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ], + policy=policy, + target_key=T5ForConditionalGeneration) return policy def postprocess(self): @@ -212,12 +228,14 @@ def module_policy(self): from transformers import T5EncoderModel base_policy = super().module_policy() - base_policy[T5EncoderModel] = ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( + + if self.shard_config.enable_tensor_parallelism: + self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription( suffix="shared", target_module=VocabParallelEmbedding1D, - ) - ]) + ), + policy=base_policy, + target_key=T5EncoderModel) return base_policy def postprocess(self): diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index 0a5aa4cc4bdc..83c08d275df3 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -13,11 +13,12 @@ class ShardConfig: Args: tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group. + enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True. enable_fused_normalization (bool): Whether to use fused layernorm, default is False. - enable_tensor_parallelism (bool): Whether to use tensor parallelism, default is True. enable_all_optimization (bool): Whether to turn on all optimization, default is False. """ tensor_parallel_process_group: ProcessGroup = None + enable_tensor_parallelism: bool = True enable_fused_normalization: bool = False enable_all_optimization: bool = False @@ -33,8 +34,11 @@ def tensor_parallel_size(self): return self._tensor_parallel_size def __post_init__(self): - # get the parallel size - self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) + if not self.enable_tensor_parallelism: + self._tensor_parallel_size = 1 + else: + # get the parallel size + self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group) # turn on all optimization if all_optimization is set to True if self.enable_all_optimization: diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index aa1424af3289..d83d9ecd39e0 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -3,12 +3,13 @@ from colossalai.shardformer import ShardConfig, ShardFormer -def build_model(model_fn): +def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True): # create new model org_model = model_fn().cuda() # shard model - shard_config = ShardConfig(enable_fused_normalization=True) + shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization, + enable_tensor_parallelism=enable_tensor_parallelism) model_copy = copy.deepcopy(org_model) shard_former = ShardFormer(shard_config=shard_config) sharded_model = shard_former.optimize(model_copy).cuda() diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py index 87c4ef65bf1a..1afedb7079ea 100644 --- a/tests/test_shardformer/test_model/test_shard_bert.py +++ b/tests/test_shardformer/test_model/test_shard_bert.py @@ -3,7 +3,14 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward @@ -33,36 +40,50 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # compare self attention grad org_grad = bert.encoder.layer[0].attention.self.query.weight.grad shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad + shard_weight = sharded_bert.encoder.layer[0].attention.self.query.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # compare embedding grad org_grad = bert.embeddings.word_embeddings.weight.grad shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad + shard_weight = sharded_bert.embeddings.word_embeddings.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -def check_bert(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_bert_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_bert') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() +def check_bert(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bert_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py index 70d902a04517..a3389652269c 100644 --- a/tests/test_shardformer/test_model/test_shard_bloom.py +++ b/tests/test_shardformer/test_model/test_shard_bloom.py @@ -3,7 +3,14 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward @@ -32,10 +39,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = bloom.h[0].self_attention.query_key_value.weight.grad shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad + shard_weight = sharded_bloom.h[0].self_attention.query_key_value.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" @@ -43,27 +54,35 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check embedding weights org_grad = bloom.word_embeddings.weight.grad shard_grad = sharded_bloom.word_embeddings.weight.grad + shard_weight = sharded_bloom.word_embeddings.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -def check_bloom(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() +def check_bloom(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_bloom_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py index a4edc14bdbc3..ee7737687d99 100644 --- a/tests/test_shardformer/test_model/test_shard_gpt2.py +++ b/tests/test_shardformer/test_model/test_shard_gpt2.py @@ -3,7 +3,14 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward @@ -32,11 +39,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check mlp grad org_grad = org_model.h[0].mlp.c_fc.weight.grad shard_grad = sharded_model.h[0].mlp.c_fc.weight.grad + shard_weight = sharded_model.h[0].mlp.c_fc.weight - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=1) - + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=1) + else: + all_shard_grad = shard_grad assert torch.allclose( org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" @@ -44,27 +54,35 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check embedding weights org_grad = org_model.wte.weight.grad shard_grad = sharded_model.wte.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = sharded_model.wte.weight + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose( org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}" -def check_gpt2(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() +def check_gpt2(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gpt2_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a98743a6143a..74b5fdd18af8 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -5,7 +5,14 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward @@ -37,35 +44,48 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = shard_llama_model.layers[0].self_attn.q_proj.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" # check embedding grad org_grad = llama_model.embed_tokens.weight.grad shard_grad = shard_llama_model.embed_tokens.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = shard_llama_model.embed_tokens.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" -def check_llama(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_gpt2_llama() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 29cf2f6beed8..25bccb13b1a8 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -6,10 +6,11 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor from colossalai.testing import ( assert_hf_output_close, - check_state_dict_equal, clear_cache_before_run, + parameterize, rerun_if_address_is_in_use, spawn, ) @@ -42,34 +43,48 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" # check embedding grad org_grad = opt_model.decoder.embed_tokens.weight.grad shard_grad = shard_opt_model.decoder.embed_tokens.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = shard_opt_model.decoder.embed_tokens.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -def check_OPTModel(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) - torch.cuda.empty_cache() +def check_OPTModel(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_t5_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 91430bce918f..0762dc09e5af 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -5,7 +5,14 @@ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor +from colossalai.testing import ( + assert_hf_output_close, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward @@ -27,19 +34,28 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo # check attention grad org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad - - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" # check self attention embed org_grad = org_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=1) + shard_weight = sharded_model.encoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=1) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" @@ -52,25 +68,34 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo assert sharded_model.shared.weight.data.data_ptr() == sharded_model.lm_head.weight.data.data_ptr() shard_grad = sharded_model.shared.weight.grad - shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] - torch.distributed.all_gather(shard_grad_list, shard_grad) - all_shard_grad = torch.cat(shard_grad_list, dim=0) + shard_weight = sharded_model.shared.weight + + if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] + torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + else: + all_shard_grad = shard_grad assert torch.allclose(org_grad, all_shard_grad, atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}" -def check_t5(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - +@parameterize('enable_fused_normalization', [True, False]) +@parameterize('enable_tensor_parallelism', [True, False]) +def run_t5_test(enable_fused_normalization, enable_tensor_parallelism): sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') - for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): - org_model, sharded_model = build_model(model_fn) + org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism) check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() +def check_t5(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_t5_test() + + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() From d1db0439a0bcd79b218548af057e8d6c007ef576 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 4 Jul 2023 10:28:31 +0800 Subject: [PATCH 48/49] [shardformer] added development protocol for standardization (#4149) --- colossalai/shardformer/README.md | 13 ++++ colossalai/shardformer/model/modeling_bert.py | 67 ------------------ .../{model => modeling}/__init__.py | 0 colossalai/shardformer/modeling/bloom.py | 69 +++++++++++++++++++ colossalai/shardformer/policies/bloom.py | 64 ++--------------- 5 files changed, 86 insertions(+), 127 deletions(-) delete mode 100644 colossalai/shardformer/model/modeling_bert.py rename colossalai/shardformer/{model => modeling}/__init__.py (100%) create mode 100644 colossalai/shardformer/modeling/bloom.py diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index f5d8bb35d91d..fca401562be6 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -321,6 +321,19 @@ This section serves as the guideline for writing new policies and register them You can create a new file in the `colossalai/shardformer/policies` folder and name the file with the model name. You can implement your policy in this file. You should not import the any model zoo library at the header section of the file because we do not want to import the library when we do not use the policy. Libraries such as `transformers` should be imported only in the function body when needed. +Please follow the following protocols when writing your policy: + +- You have to make a clear decision what you want to replace exactly in the original PyTorch module + - Use `ModulePolicyDescription.attribute_replacement` to replace the module attributes + - Use `ModulePolicyDescription.param_replacement` to replace the module parameters + - Use `ModulePolicyDescription.sub_module_replacement` to replace the submodules completely. The target module should implement the `from_native_module` for the . + - Use `ModulePolicyDescription.method_replacement` to replace the module methods. **These replacement methods should be put in the `shardformer/modeling/.py`**. +- You can implement the `ParallelModule` for primitive modules in the `shardformer/layer/.py` file. Primitive modules refer to modules which are not composed of other modules. For example, the `torch.nn.Linear` module is a primitive module while modules such as `BertEncoder` module in the `transformers` library is a composite module. Primitive modules do not nested inner `nn.Module` members. For composite modules, you should consider using `ModulePolicyDescription` to implement your replacement. +- `ParallelModule` is meant to be used in two ways: `ParallelModule.from_native_module` to convert native PyTorch module to the `ParallelModule` and `ParallelModule(...)` to instantiate the module directly just like a normal PyTorch module. `ParallelModule` should be only implemented for modules whose weights are sharded. If you want to make your module compatible with the `ModulePolicyDescription.sub_module_replacement` and there is no weight sharding in your module, you can just implement the `from_native_module` method without inheriting the `ParallelModule` like `colossalai/shardformer/layer/normalization.py`. +- **Do not import any file in the `colossalai/shardformer/policies` and `colossalai/shardformer/modeling` to avoid unwanted import error**. For example, a file in these folders accidentally imports `transformers` library at the top of the file, then the user will have to install `transformers` library even if they do not use this file. Any file in the `modeling` folder should be only imported by the policy file. A policy implementation should be only imported dynamically via the autopolicy or manually via the `ShardFormer` module. +- Try to keep your import statement on third-party libraries such as `transformers` within the function body instead of the header section of the file. This is because we do not want to import the library when we do not use the policy. + + - Step 2. Register your policy to the autopolicy Next, you need to register your policy in the `colossalai/shardformer/policies/autopolicy.py` file. diff --git a/colossalai/shardformer/model/modeling_bert.py b/colossalai/shardformer/model/modeling_bert.py deleted file mode 100644 index bd07ab80c00d..000000000000 --- a/colossalai/shardformer/model/modeling_bert.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Any, Dict, List, Type - -import torch -import torch.nn as nn -from torch.nn import CrossEntropyLoss -from transformers import BertForMaskedLM -from transformers.models.bert.modeling_bert import MaskedLMOutput - -from ..layer.dist_crossentropy import applyDistCrossEntropy - - -class BertForMaskedLM_(BertForMaskedLM): - - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - **kwargs, - ): - # print("[Inject OK] Injected forward method") - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - prediction_scores = self.cls(sequence_output) - - masked_lm_loss = None - - if labels is not None: - masked_lm_loss = applyDistCrossEntropy(prediction_scores, labels) - # if labels is not None: - # loss_fct = CrossEntropyLoss() # -100 index = padding token - # masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/colossalai/shardformer/model/__init__.py b/colossalai/shardformer/modeling/__init__.py similarity index 100% rename from colossalai/shardformer/model/__init__.py rename to colossalai/shardformer/modeling/__init__.py diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py new file mode 100644 index 000000000000..a3d774ff2abb --- /dev/null +++ b/colossalai/shardformer/modeling/bloom.py @@ -0,0 +1,69 @@ +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + + +def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: + + def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, + dtype: torch.dtype) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. Based on + https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). + num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + import math + + if dist.is_initialized(): + world_size = dist.get_world_size(process_group) + num_heads = num_heads * world_size + + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2**math.floor(math.log2(num_heads)) + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + device=attention_mask.device, + dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, + 1 + 2 * num_remaining_heads, + 2, + device=attention_mask.device, + dtype=torch.int32) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + if dist.is_initialized(): + num_heads_per_rank = int(num_heads / dist.get_world_size(process_group)) + offset = dist.get_rank(process_group) * num_heads_per_rank + alibi = alibi.view(batch_size, num_heads, 1, seq_length) + alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] + return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) + else: + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) + + return build_bloom_alibi_tensor diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 030774a919d7..a0b5340f72bc 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -1,70 +1,12 @@ -import torch -import torch.distributed as dist import torch.nn as nn import colossalai.shardformer.layer as col_nn from .._utils import getattr_, setattr_ +from ..modeling.bloom import build_bloom_alibi_tensor_fn from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: - """ - Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it - relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value - `softmax(l+a) = softmax(l)`. Based on - https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 - TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. - - Args: - Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) - attention_mask (`torch.Tensor`): - Token-wise attention mask, this should be of shape (batch_size, max_seq_len). - num_heads (`int`, *required*): - number of heads - dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): - dtype of the output tensor - """ - import math - - if dist.is_initialized(): - world_size = dist.get_world_size() - num_heads = num_heads * world_size - - batch_size, seq_length = attention_mask.shape - closest_power_of_2 = 2**math.floor(math.log2(num_heads)) - base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), - device=attention_mask.device, - dtype=torch.float32) - powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) - slopes = torch.pow(base, powers) - - if closest_power_of_2 != num_heads: - extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - device=attention_mask.device, - dtype=torch.float32) - num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) - slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - - # Note: alibi will added to the attention bias that will be applied to the query, key product of attention - # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) - # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) - # => the query_length dimension will then be broadcasted correctly - # This is more or less identical to T5's relative position bias: - # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 - arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] - alibi = slopes[..., None] * arange_tensor - if dist.is_initialized(): - num_heads_per_rank = int(num_heads / dist.get_world_size()) - offset = dist.get_rank() * num_heads_per_rank - alibi = alibi.view(batch_size, num_heads, 1, seq_length) - alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] - return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) - else: - return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) - - class BloomPolicy(Policy): def config_sanity_check(self): @@ -120,7 +62,9 @@ def module_policy(self): attribute_replacement={ "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, }, - method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor}, + method_replacement={ + "build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) + }, sub_module_replacement=[ SubModuleReplacementDescription( suffix="word_embeddings", From dd9fe396db75ead09ac9c24ee598be44f6a34f66 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 4 Jul 2023 14:47:53 +0800 Subject: [PATCH 49/49] [chat] removed cache file (#4155) --- applications/Chat/coati/trainer/.sft.py.swp | Bin 20480 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 applications/Chat/coati/trainer/.sft.py.swp diff --git a/applications/Chat/coati/trainer/.sft.py.swp b/applications/Chat/coati/trainer/.sft.py.swp deleted file mode 100644 index 302cf2a775338fb4fcd6b9b12c1a8e80f3969a01..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20480 zcmeHOU5q4E6)psY74VP5@FaJ89H$z3t7mp0aohBo-E~(9yMp_JSthfc>h9aq#Z*@{ zRn;uJ!vHZRh#HOZ#2YFJFB0QVSObsxWI{~Lk{BNh;)6b@tf-4W;QF2WQ+2DldsbdF zA-9unZdKiL?>+as=box_s;0Yq@0~~4UT@jK?*oo=(;FxI-#GnI=k?bd=l&oa`tc}; zsBW8IF4yNr58UI0{+tGnwG2>WBO9Z2IGI81(!sh@-T5aOh{vUW~J5 zmiR%$lV&=|yz$v>bg)&H0n5OZ7`V~dvwKBs^@Z9_cGHJWUa5htw+vVYECZGS%YbFT zGGH073|I!UFsSOzQumI2Ga|A7I&?>KiO*ng4|0C@keHBLF9S~i4+E<} z3%D7$h(y4zfBv-t%5*a`ZrO?|Z0V174>laU{aAzEycM%t>?or)eT z7)T}dwb^d(4(3OF7NniTk2X5XZoM_Fg3>&mWaCMedUQH1q-rrkF||?L=bgLq$?QqGObFGlw zsur=BXQ9tuXN* zT09OpAri(tKNK?cdNBm{(In*^F^V*UJ|D0FMx$9b)UIt9LaZCdA#DAyAEq3p115z* z{>`AYG)r^FmE&kK>WKXTk05G(MX;eL>PT*BWlX2cfHuQQt2;ICh7{ju+ zyrA2V(pmqIh0}G%TZ!m0! zuzJ@o<|mwylTIh&ahzlv_n``0Jis|OV;YhZYFtuXv)eOs7m2C0u1LIu6La(7W#S>G1>6z+*f z#61zckXD6kb!|=ILP}U%i`_kAn<|c4OtV$A62@4Zk{!!4@_f0?{%%}z zh}c@KCQ3RMT6H$ylGan^_JdwVH;sWm3A2=CG22`Ulcm-Q-*x*|y~VO2XG#`z@Wt%R zFbFxg4YI%wgL80=vP*JA`5^hSF}J*I<*`=R)DVR{MxhhdpCr_^V44oplT1EBv6jrm zLYMp9t-6In>cS1e&XpfS@1n=N8~JRPLiriptf+2;5EDXj_FAjFn-5K6S(muQS)X-u*-Q38O%Hk?52`wot8x1|xR_{*cYCBZT^NQ_SOZVw! zanc)lQRE$s>8_)Ckd-CA_rv0hN8c0j>twB9nOT-_SZH6mW=}9hM%)X1JmDWdd_pyL zRcYMl|6v-xGf3hQ^CnpkrYsnd3oM>j@#GPA`uxG5$CaF#cEEGtX}{LU_pEAE*sdDj zpz=|wzUza2Ri)H*WL1Xe-4qKm;z{b^gg(~RlBFB>Eb(}QN5MH6*NFMy$*>sco`lP> z`X{6yPY$3ucN_mSSki9eU&b}T8^sG+A39cQi1ijjdSs(oqVSidoG?T!l3S^x_z(|v zxl(UJ0sR=R1f+4{IX_E+?u71NeoE0trh5wgVA3!DOO z27ZE={t@69@Ht>7@G5lu4e%WB7*GRt0yKvuD(%lQU>UFsSOzQumVs>I)s>jiJTo%Io9T8B?Yr36Urjv+!=8{+FB~71M~<8-s$A*#%|0$B8Agnj(`}mV2&{9PSOKJJ^azp3@v(PnE z!oFvzi?B{4y2|MzYJiQ2?KN%6rMDy7Y}o~s2*;LnH)}M%TFW%rt+IY{$nyRcq1em~ zF7pyw>;PriS@Eig9mCdTo0XkxIsvt1$PMT)=8rgVmBZ54{$zwJq?-9#ATos!(2Ju* zOlloD$&4rDj-=CYOnH5*UxO5+=^C@#=FNgBHiJI1I)4ub(zU9NDNsGG2}&U?%K2M~ z+SS4Yg2J5gB)M>$NHb)e5Lw#Q_QrZG26NCeC8`7U!w>MbEwzq59?&SKNOkO z%6pM&i^`;E$toL3w-AzfzQ{w)NN-I?iWtKD+hn9b%7