From f3a8a7612e0a6e3842ec745b8583b0b80e3760cd Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Thu, 8 Jun 2023 18:01:33 +0800 Subject: [PATCH 1/8] add dist dropout in model --- colossalai/shardformer/policies/basepolicy.py | 8 ++++++++ colossalai/shardformer/policies/bert.py | 8 ++++++-- colossalai/shardformer/shard/sharder.py | 17 ++++++++++++----- colossalai/shardformer/shard/slicer.py | 4 ++-- 4 files changed, 28 insertions(+), 9 deletions(-) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 644d115a270e..c64b8b09509b 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -62,6 +62,14 @@ class Row_Layer(Layer): pass +@dataclass +class Dropout_Layer(Layer): + r""" + Class for dropout layer in MegatronLM + """ + p: int = None + + class Policy(): r""" The base class for all the policies diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 89b32f065c27..803faf7956a6 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -5,7 +5,7 @@ import colossalai.shardformer.layer.layers as col_nn -from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer +from .basepolicy import Argument, Col_Layer, Dropout_Layer, Policy, Row_Layer class BertPolicy(Policy): @@ -71,6 +71,10 @@ def attn_in() -> List: bias="attention.self.value.bias", replace_layer=col_nn.Linear1D_Col, ), + Dropout_Layer( + p="attention.self.dropout.p", + replace_layer=col_nn.Dropout1D, + ), Col_Layer( weight="crossattention.self.query.weight", bias="crossattention.self.query.bias", @@ -141,7 +145,7 @@ def unembedding() -> List: weight="decoder.weight", bias="decoder.bias", replace_layer=col_nn.Linear1D_Col, - # gather_output=True, + gather_output=True, ) ] diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 1ada75e06b67..6633b27e7938 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -145,10 +145,12 @@ def shard_one_layer( bias_attr = policy_layer.bias replace_layer_cls = policy_layer.replace_layer ignore = policy_layer.ignore - n_cast = policy_layer.n_cast reversed = policy_layer.reversed - if policy_layer.__class__.__name__ == "Col_Layer": - gather_output = policy_layer.gather_output + n_cast = policy_layer.n_cast + if policy_layer.__class__.__name__ == 'Dropout_Layer': + dropout_p_attr = policy_layer.p + # if policy_layer.__class__.__name__ == "Col_Layer": + # gather_output = policy_layer.gather_output if weight_attr is not None: if hasattr_(org_layer, weight_attr): @@ -167,8 +169,8 @@ def shard_one_layer( continue # set the sliced weight and bias to the new nn_col layer - assert weight is not None or bias is not None - layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr) + assert weight is not None or bias is not None or dropout_p_attr is not None + layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr or dropout_p_attr) # slice weight and bias weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed) @@ -181,6 +183,7 @@ def shard_one_layer( weight.shape[0], bias=False if bias is None else True) elif replace_layer_cls.__name__ == "Linear1D_Col": + gather_output = policy_layer.gather_output replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], bias=False if bias is None else True, @@ -192,6 +195,10 @@ def shard_one_layer( getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) self.set_param(replace_layer, weight, bias) + elif isinstance(getattr_(org_layer, layer_attr), nn.Dropout): + p = getattr_(org_layer, dropout_p_attr, ignore=True) + replace_layer = replace_layer_cls(p=p) + setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) else: raise NotImplementedError( f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py index 6d35bd193fed..5f7f94055fa0 100644 --- a/colossalai/shardformer/shard/slicer.py +++ b/colossalai/shardformer/shard/slicer.py @@ -1,6 +1,6 @@ import torch -from ..policies.basepolicy import Col_Layer, Layer, Row_Layer +from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer from .shard_config import ShardConfig dim_mapping = {Col_Layer: 1, Row_Layer: 0} @@ -33,7 +33,7 @@ def slice_weight_bias( bias: (:class:`torch.nn.Module`): The bias of the layer policy_layer_class (:class:`Policy`): The class represent how to slice the tensor """ - if policy_layer_cls == Layer: + if policy_layer_cls in [Layer, Dropout_Layer]: return weight, bias dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls]) From 4cab8227ff28be2df09ca2f630c098ba761e6fd8 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 9 Jun 2023 11:07:37 +0800 Subject: [PATCH 2/8] update docstring and bert policy with dropout --- colossalai/shardformer/policies/basepolicy.py | 6 +++--- colossalai/shardformer/policies/bert.py | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index c64b8b09509b..e596ef20cd93 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -46,7 +46,7 @@ class Layer: @dataclass class Col_Layer(Layer): r""" - Class for col shard layer in MegatronLM + Class for col shard layer in tensor parrallel Args: gather_output (bool): Whether to gather the output of the layer @@ -57,7 +57,7 @@ class Col_Layer(Layer): @dataclass class Row_Layer(Layer): r""" - Class for col shard layer in MegatronLM + Class for col shard layer in tensor parrallel """ pass @@ -65,7 +65,7 @@ class Row_Layer(Layer): @dataclass class Dropout_Layer(Layer): r""" - Class for dropout layer in MegatronLM + Class for dropout layer in tensor parrallel """ p: int = None diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 803faf7956a6..0b149bd33b96 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -103,6 +103,10 @@ def attn_out() -> List: bias="attention.output.dense.bias", replace_layer=col_nn.Linear1D_Row, ), + Dropout_Layer( + p="attention.output.dropout.p", + replace_layer=col_nn.Dropout1D, + ), Row_Layer( weight="crossattention.output.dense.weight", bias="crossattention.output.dense.bias", @@ -129,6 +133,10 @@ def mlp_out() -> List: bias="output.dense.bias", replace_layer=col_nn.Linear1D_Row, ), + Dropout_Layer( + p="output.dropout.p", + replace_layer=col_nn.Dropout1D, + ) ] @staticmethod From e14ea613aaa2552fec4470298ac86de274df917e Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 9 Jun 2023 17:11:20 +0800 Subject: [PATCH 3/8] refactor basepolicy and sharded, update bert --- colossalai/shardformer/policies/basepolicy.py | 22 +++- colossalai/shardformer/policies/bert.py | 81 ++++++------ colossalai/shardformer/shard/sharder.py | 121 ++++++++++-------- colossalai/shardformer/utils/utils.py | 2 +- 4 files changed, 126 insertions(+), 100 deletions(-) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index e596ef20cd93..501b810c2ff6 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -25,8 +25,7 @@ class Layer: The layer object for the policy Args: - weight (str): The weight suffix of the layer - bias (str): The bias suffix of the layer + suffix: (str): the suffix of the layer. replace_layer (:class:`colosalai.nn`): The layer to replace the original layer ignore (bool): Whether to ignore this layer if it is not in the model reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], @@ -35,8 +34,7 @@ class Layer: but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and each device should have a part of Q, K and V weight. """ - weight: str = None - bias: str = None + suffix: str = None replace_layer: Any = None ignore: bool = False reversed: bool = False @@ -49,8 +47,12 @@ class Col_Layer(Layer): Class for col shard layer in tensor parrallel Args: + weight (str): The weight suffix of the layer + bias (str): The bias suffix of the layer gather_output (bool): Whether to gather the output of the layer """ + weight: str = None + bias: str = None gather_output: bool = False @@ -58,16 +60,24 @@ class Col_Layer(Layer): class Row_Layer(Layer): r""" Class for col shard layer in tensor parrallel + + Args: + weight (str): The weight suffix of the layer + bias (str): The bias suffix of the layer """ - pass + weight: str = None + bias: str = None @dataclass class Dropout_Layer(Layer): r""" Class for dropout layer in tensor parrallel + + Args: + p (str): The dropout rate suffix of the layer """ - p: int = None + p: str = None class Policy(): diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index e75e4ca5b446..4ae4d812197f 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -28,23 +28,15 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: Argument( attr_dict={ # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size, }, param_funcs=[ BertPolicy.embedding, ]), BertLMPredictionHead: - Argument( - attr_dict={ - # 1. shard vocab size - # "word_embeddings.num_embeddings": config.vocab_size // world_size, - # 2. add the size of the sliced embedding layer excluding the last slice - }, - param_funcs=[ - BertPolicy.unembedding, - ]) + Argument(attr_dict={}, param_funcs=[ + BertPolicy.unembedding, + ]) } @staticmethod @@ -57,39 +49,46 @@ def binding_policy() -> Dict: def attn_in() -> List: return [ Col_Layer( - weight="attention.self.query.weight", - bias="attention.self.query.bias", + suffix="attention.self.query", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), Col_Layer( - weight="attention.self.key.weight", - bias="attention.self.key.bias", + suffix="attention.self.key", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), Col_Layer( - weight="attention.self.value.weight", - bias="attention.self.value.bias", + suffix="attention.self.value", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), Dropout_Layer( - p="attention.self.dropout.p", + suffix="attention.self.dropout", + p="p", replace_layer=col_nn.Dropout1D, ), Col_Layer( - weight="crossattention.self.query.weight", - bias="crossattention.self.query.bias", + suffix="crossattention.self.query", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ignore=True, ), Col_Layer( - weight="crossattention.self.key.weight", - bias="crossattention.self.key.bias", + suffix="crossattention.self.key", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ignore=True, ), Col_Layer( - weight="crossattention.self.value.weight", - bias="crossattention.self.value.bias", + suffix="crossattention.self.value", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ignore=True, ), @@ -99,17 +98,20 @@ def attn_in() -> List: def attn_out() -> List: return [ Row_Layer( - weight="attention.output.dense.weight", - bias="attention.output.dense.bias", + suffix="attention.output.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Row, ), Dropout_Layer( - p="attention.output.dropout.p", + suffix="attention.output.dropout", + p="p", replace_layer=col_nn.Dropout1D, ), Row_Layer( - weight="crossattention.output.dense.weight", - bias="crossattention.output.dense.bias", + suffix="crossattention.output.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Row, ignore=True, ), @@ -119,8 +121,9 @@ def attn_out() -> List: def mlp_in() -> List: return [ Col_Layer( - weight="intermediate.dense.weight", - bias="intermediate.dense.bias", + suffix="intermediate.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, ), ] @@ -129,12 +132,14 @@ def mlp_in() -> List: def mlp_out() -> List: return [ Row_Layer( - weight="output.dense.weight", - bias="output.dense.bias", + suffix="output.dense", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Row, ), Dropout_Layer( - p="output.dropout.p", + suffix="output.dropout", + p="p", replace_layer=col_nn.Dropout1D, ) ] @@ -142,7 +147,8 @@ def mlp_out() -> List: @staticmethod def embedding() -> List: return [Col_Layer( - weight="word_embeddings.weight", + suffix="word_embeddings", + weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D, )] @@ -150,8 +156,9 @@ def embedding() -> List: def unembedding() -> List: return [ Col_Layer( - weight="decoder.weight", - bias="decoder.bias", + suffix="decoder", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, gather_output=True, ) diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 03d0b718c0c3..3e224a8dfd8c 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -5,7 +5,7 @@ from transformers.pytorch_utils import Conv1D from ..policies.autopolicy import get_autopolicy -from ..policies.basepolicy import Policy +from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer from ..utils.utils import getattr_, hasattr_, setattr_ from .shard_config import ShardConfig from .slicer import Slicer @@ -141,72 +141,81 @@ def shard_one_layer( for func in param_funcs: policy_layers = func() for policy_layer in policy_layers: - weight = None - bias = None - weight_attr = policy_layer.weight - bias_attr = policy_layer.bias + suffix = policy_layer.suffix replace_layer_cls = policy_layer.replace_layer ignore = policy_layer.ignore reversed = policy_layer.reversed n_cast = policy_layer.n_cast - if policy_layer.__class__.__name__ == 'Dropout_Layer': - dropout_p_attr = policy_layer.p - if policy_layer.__class__.__name__ == "Col_Layer": - gather_output = policy_layer.gather_output and self.shard_config.gather_output - - if weight_attr is not None: - if hasattr_(org_layer, weight_attr): - weight = getattr_(org_layer, weight_attr) - elif not ignore: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") - - if bias_attr is not None: - if hasattr_(org_layer, bias_attr): - bias = getattr_(org_layer, bias_attr) - elif not ignore: - raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") - - # dont have the attribute in policy, and ignore is true - if weight is None and bias is None and ignore: - continue - - # set the sliced weight and bias to the new nn_col layer - assert weight is not None or bias is not None or dropout_p_attr is not None - layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr or dropout_p_attr) - # slice weight and bias - weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed) + assert replace_layer_cls is not None, 'replace_layer should not be None' # create new object to replace the origin layer - if replace_layer_cls is not None: - if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)): - if replace_layer_cls.__name__ == "Linear1D_Row": - replace_layer = replace_layer_cls(weight.shape[1], - weight.shape[0], - bias=False if bias is None else True) - elif replace_layer_cls.__name__ == "Linear1D_Col": - gather_output = policy_layer.gather_output - replace_layer = replace_layer_cls(weight.shape[0], - weight.shape[1], - bias=False if bias is None else True, - gather_output=gather_output) - setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) - self.set_param(replace_layer, weight, bias) - elif isinstance(getattr_(org_layer, layer_attr), nn.Embedding): + # Linear + suffix_layer = getattr_(org_layer, suffix, ignore=True) + assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}" + if suffix_layer is None and ignore: + continue + if isinstance(policy_layer, (Col_Layer, Row_Layer)): + weight = None + bias = None + weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None + bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None + + if weight_attr is not None: + if hasattr_(org_layer, weight_attr): + weight = getattr_(org_layer, weight_attr) + else: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {weight_attr}") + + if bias_attr is not None: + if hasattr_(org_layer, bias_attr): + bias = getattr_(org_layer, bias_attr) + else: + raise ValueError(f"Layer {org_layer.__class__.__qualname__} has no attribute {bias_attr}") + + # set the sliced weight and bias to the new nn_col layer + assert weight is not None or bias is not None + + # slice weight and bias + weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed) + + if replace_layer_cls.__name__ == "Linear1D_Row": + replace_layer = replace_layer_cls(weight.shape[1], + weight.shape[0], + bias=False if bias is None else True) + elif replace_layer_cls.__name__ == "Linear1D_Col": + gather_output = policy_layer.gather_output and self.shard_config.gather_output + replace_layer = replace_layer_cls(weight.shape[0], + weight.shape[1], + bias=False if bias is None else True, + gather_output=gather_output) + elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D": replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1], - getattr_(org_layer, f"{layer_attr}.padding_idx", ignore=True)) - setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) - self.set_param(replace_layer, weight, bias) - elif isinstance(getattr_(org_layer, layer_attr), nn.Dropout): - p = getattr_(org_layer, dropout_p_attr, ignore=True) - replace_layer = replace_layer_cls(p=p) - setattr_(org_layer, layer_attr, replace_layer, ignore=ignore) + getattr_(org_layer, f"{suffix}.padding_idx", ignore=True)) + # setattr_(org_layer, suffix, replace_layer, ignore=ignore) + # self.set_param(replace_layer, weight, bias) else: raise NotImplementedError( - f"Replacing {getattr_(org_layer, layer_attr).__class__} is not implemented so far") - # do not replace the layer object, just replace the weight and bias + f"Replacing to {replace_layer_cls.__name__} is not implemented so far") + setattr_(org_layer, suffix, replace_layer, ignore=ignore) + self.set_param(replace_layer, weight, bias) + # dropout + elif isinstance(policy_layer, Dropout_Layer): + p_attr = suffix + '.' + policy_layer.p + p = getattr_(org_layer, p_attr, ignore=True) + replace_layer = replace_layer_cls(p) + setattr_(org_layer, suffix, replace_layer, ignore=ignore) else: - self.set_param(org_layer, layer_attr, weight, bias) + raise NotImplementedError( + f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far") + # do not replace the layer object, just replace the weight and bias + # else: + # self.set_param(org_layer, layer_attr, weight, bias) + + # if policy_layer.__class__.__name__ == 'Dropout_Layer': + # dropout_p_attr = policy_layer.p + # if policy_layer.__class__.__name__ == "Col_Layer": + # gather_output = policy_layer.gather_output and self.shard_config.gather_output def set_param(self, layer: Any, diff --git a/colossalai/shardformer/utils/utils.py b/colossalai/shardformer/utils/utils.py index eb84edd88404..2c02b6f69a3e 100644 --- a/colossalai/shardformer/utils/utils.py +++ b/colossalai/shardformer/utils/utils.py @@ -37,7 +37,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False): setattr(obj, attrs[-1], value) -def getattr_(obj, attr: str, ignore: bool = None): +def getattr_(obj, attr: str, ignore: bool = False): r""" Get the object's multi sublevel attr From a6f6673484f24fc164bad9f4638a592d12682348 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 9 Jun 2023 17:20:34 +0800 Subject: [PATCH 4/8] update format --- colossalai/shardformer/policies/basepolicy.py | 32 +++++++++---------- colossalai/shardformer/policies/bert.py | 25 ++++++--------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 501b810c2ff6..26109c859f28 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -1,7 +1,7 @@ # part of code modified from https://github.com/tunib-ai/parallelformers from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Tuple, Type +from typing import Any, Callable, Dict, List, Tuple, Union import torch.nn as nn @@ -144,7 +144,7 @@ def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument raise NotImplementedError @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: + def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]: r""" Return the dict for the inject model @@ -157,9 +157,9 @@ def inject_policy() -> Tuple[nn.Module, nn.Module]: return None @staticmethod - def binding_policy() -> Dict: + def binding_policy() -> Union[Dict[str, str], None]: r""" - Return the dict for the binding model + Return the dict for the binding model, None means no need to bind Return: This method should return the binding relationship for some layers share the weight or bias, @@ -172,7 +172,7 @@ def binding_policy() -> Dict: return None @staticmethod - def attn_in() -> List: + def attn_in() -> Union[List, None]: r""" Attention qkv layer In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be @@ -182,52 +182,52 @@ def attn_in() -> List: Returns: List[Layer]: List of layer object, each layer is the new """ - return NotImplementedError + return None @staticmethod - def attn_out() -> List: + def attn_out() -> Union[List, None]: r""" Attention output projection layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_in() -> List: + def mlp_in() -> Union[List, None]: r""" h -> 4h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_out() -> List: + def mlp_out() -> Union[List, None]: r""" 4h -> h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def embedding() -> List: + def embedding() -> Union[List, None]: r""" Partially slice the embedding layer Return: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def unembedding() -> List: + def unembedding() -> Union[List, None]: r""" - Partially slice the embedding layer + Partially slice the embedding layer, None means there is no unembedding layer Return: List[Layer]: List of layer object diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 4ae4d812197f..b5561010aaa1 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -40,13 +40,13 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: } @staticmethod - def binding_policy() -> Dict: + def binding_policy(): return { "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", } @staticmethod - def attn_in() -> List: + def attn_in(): return [ Col_Layer( suffix="attention.self.query", @@ -95,7 +95,7 @@ def attn_in() -> List: ] @staticmethod - def attn_out() -> List: + def attn_out(): return [ Row_Layer( suffix="attention.output.dense", @@ -118,7 +118,7 @@ def attn_out() -> List: ] @staticmethod - def mlp_in() -> List: + def mlp_in(): return [ Col_Layer( suffix="intermediate.dense", @@ -129,7 +129,7 @@ def mlp_in() -> List: ] @staticmethod - def mlp_out() -> List: + def mlp_out(): return [ Row_Layer( suffix="output.dense", @@ -145,7 +145,7 @@ def mlp_out() -> List: ] @staticmethod - def embedding() -> List: + def embedding(): return [Col_Layer( suffix="word_embeddings", weight="weight", @@ -153,7 +153,7 @@ def embedding() -> List: )] @staticmethod - def unembedding() -> List: + def unembedding(): return [ Col_Layer( suffix="decoder", @@ -173,7 +173,7 @@ def unembedding() -> List: class BertForMaskedLMPolicy(BertPolicy): @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: + def inject_policy(): # return (BertForMaskedLM, BertForMaskedLM_) return None @@ -181,10 +181,5 @@ def inject_policy() -> Tuple[nn.Module, nn.Module]: class BertForSequenceClassificationPolicy(BertPolicy): @staticmethod - def inject_policy() -> Dict: - return {} - - -# model = BertForMaskedLM.from_pretrained("bert-base-uncased") -# _ = BertForMaskedLMPolicy(model) -# print(isinstance(model,list(_.inject_policy().keys())[0])) + def inject_policy(): + return None From 9622f025b72c5122d5873bd389bc800060772bee Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 9 Jun 2023 17:33:06 +0800 Subject: [PATCH 5/8] update gpt2 policy --- colossalai/shardformer/policies/gpt2.py | 40 +++++++++++++++---------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 44dc9c72f986..0d4342e75783 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -40,19 +40,22 @@ def argument_policy(config, world_size): @staticmethod def attn_in() -> List: return [ - Col_Layer(weight="attn.c_attn.weight", - bias="attn.c_attn.bias", + Col_Layer(suffix="attn.c_attn", + weight="weight", + bias="bias", n_cast=3, reversed=True, replace_layer=col_nn.Linear1D_Col), - Col_Layer(weight="crossattention.c_attn.weight", - bias="crossattention.c_attn.bias", + Col_Layer(suffix="crossattention.c_attn", + weight="weight", + bias="bias", n_cast=2, reversed=True, ignore=True, replace_layer=col_nn.Linear1D_Col), - Col_Layer(weight="crossattention.q_attn.weight", - bias="crossattention.q_attn.bias", + Col_Layer(suffix="crossattention.q_attn", + weight="weight", + bias="bias", reversed=True, ignore=True, replace_layer=col_nn.Linear1D_Col) @@ -61,12 +64,14 @@ def attn_in() -> List: @staticmethod def attn_out() -> List: return [ - Row_Layer(weight="attn.c_proj.weight", - bias="attn.c_proj.bias", + Row_Layer(suffix="attn.c_proj", + weight="weight", + bias="bias", reversed=True, replace_layer=col_nn.Linear1D_Row), - Row_Layer(weight="crossattention.c_proj.weight", - bias="crossattention.c_proj.bias", + Row_Layer(suffix="crossattention.c_proj", + weight="weight", + bias="bias", reversed=True, ignore=True, replace_layer=col_nn.Linear1D_Row) @@ -75,21 +80,23 @@ def attn_out() -> List: @staticmethod def mlp_in() -> List: return [ - Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col), + Col_Layer(suffix="mlp.c_fc", weight="weight", bias="bias", reversed=True, + replace_layer=col_nn.Linear1D_Col), ] @staticmethod def mlp_out() -> List: return [ - Row_Layer(weight="mlp.c_proj.weight", - bias="mlp.c_proj.bias", + Row_Layer(suffix="mlp.c_proj", + weight="weight", + bias="bias", reversed=True, replace_layer=col_nn.Linear1D_Row) ] @staticmethod def embedding() -> List: - return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)] + return [Col_Layer(suffix="wte", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)] from transformers import GPT2LMHeadModel @@ -111,8 +118,9 @@ def argument_policy(config, world_size): @staticmethod def unembedding() -> List: return [ - Col_Layer(weight="lm_head.weight", - bias="lm_head.bias", + Col_Layer(suffix="lm_head", + weight="weight", + bias="bias", replace_layer=col_nn.Linear1D_Col, gather_output=True) ] From dfa7c590cb7de4e643fa53a34d2f81eab04cb2d5 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 9 Jun 2023 17:55:51 +0800 Subject: [PATCH 6/8] update bert policy --- colossalai/shardformer/policies/basepolicy.py | 4 +- colossalai/shardformer/policies/bert.py | 39 +++++++++++-------- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index 26109c859f28..d55df59fdc8b 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -100,14 +100,14 @@ class for the example. """ @staticmethod - def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]: + def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]: r""" Return the dict for the modify policy, the key is the original layer class and the value is the argument for the modify layer Args: model_config (:class:`tansformer.Config`): The config of transformer model - shard_config (:class:`ShardConfig`): The config for sharding model + world_size (int)): The world size of sharding model Return: Dict for the modify policy, diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index b5561010aaa1..67e910d521e9 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -33,10 +33,6 @@ def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: param_funcs=[ BertPolicy.embedding, ]), - BertLMPredictionHead: - Argument(attr_dict={}, param_funcs=[ - BertPolicy.unembedding, - ]) } @staticmethod @@ -152,18 +148,6 @@ def embedding(): replace_layer=col_nn.VocabParallelEmbedding1D, )] - @staticmethod - def unembedding(): - return [ - Col_Layer( - suffix="decoder", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, - ) - ] - from transformers import BertForMaskedLM @@ -172,11 +156,34 @@ def unembedding(): class BertForMaskedLMPolicy(BertPolicy): + @staticmethod + def argument_policy(config, world_size): + base_argument = BertPolicy.argument_policy(config, world_size) + argument = { + BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ + BertForMaskedLMPolicy.unembedding, + ]), + } + argument.update(base_argument) + return argument + @staticmethod def inject_policy(): # return (BertForMaskedLM, BertForMaskedLM_) return None + @staticmethod + def unembedding(): + return [ + Col_Layer( + suffix="decoder", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ) + ] + class BertForSequenceClassificationPolicy(BertPolicy): From 8d1a3e4f8bf71fbd49688c6583e192c5a671dd8f Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Fri, 9 Jun 2023 18:07:22 +0800 Subject: [PATCH 7/8] remove unused code --- colossalai/shardformer/shard/sharder.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index 3e224a8dfd8c..95184cfe6929 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -208,14 +208,6 @@ def shard_one_layer( else: raise NotImplementedError( f"Replacing {getattr_(org_layer, suffix).__class__} is not implemented so far") - # do not replace the layer object, just replace the weight and bias - # else: - # self.set_param(org_layer, layer_attr, weight, bias) - - # if policy_layer.__class__.__name__ == 'Dropout_Layer': - # dropout_p_attr = policy_layer.p - # if policy_layer.__class__.__name__ == "Col_Layer": - # gather_output = policy_layer.gather_output and self.shard_config.gather_output def set_param(self, layer: Any, From a1ae48cca587783e8319dc882e2f9cde0e7335c7 Mon Sep 17 00:00:00 2001 From: FoolPlayer <498107402@qq.com> Date: Mon, 12 Jun 2023 14:35:57 +0800 Subject: [PATCH 8/8] update readme for new policy usage --- colossalai/shardformer/README.md | 80 +++++++++++++++++++------------- 1 file changed, 47 insertions(+), 33 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 222626db3e9d..b8357c203939 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -55,7 +55,7 @@ colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py ## 💡 Policy -If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. +If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. Please refer to any policy that we have pre-established, like [bert policy](./policies/bert.py) or [gpt2 policy](./policies/gpt2.py). You should do: @@ -68,7 +68,7 @@ You should do: - Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method. 4. Overwrite or add the param functions - These functions use a suffix to record the path of weight or bias for the layer. - - The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively. + - The return is a list contains some `Col_Layer`, `Row_Layer` or `Dropout_Layer` objects, which means slice along col and row respectively or as dropout layer, refer to CLASS `Layer` for more details. 5. Overwrite `binding_policy` (Optional) - Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers. - This function will return a dict, the key and value are the suffix of weight need to be binded. @@ -123,7 +123,7 @@ class CustomPolicy(Policy): raise NotImplementedError @staticmethod - def inject_policy() -> Tuple[nn.Module, nn.Module]: + def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]: r""" Return the dict for the inject model @@ -133,12 +133,12 @@ class CustomPolicy(Policy): (OrignModel, CustomModel) in `CustomModel`, we can overwrite the forward and backward process """ - return () + return None @staticmethod - def binding_policy() -> Dict: + def binding_policy() -> Union[Dict[str, str], None]: r""" - Return the dict for the binding model + Return the dict for the binding model, None means no need to bind Return: This method should return the binding relationship for some layers share the weight or bias, @@ -148,69 +148,70 @@ class CustomPolicy(Policy): "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight", } """ - return NotImplementedError + return None @staticmethod - def attn_in() -> List: - """ + def attn_in() -> Union[List, None]: + r""" Attention qkv layer + In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be + ``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters + in ``Layer`` object can refer to the ``Layer`` class. Returns: List[Layer]: List of layer object, each layer is the new """ - return NotImplementedError + return None @staticmethod - def attn_out() -> List: - """ + def attn_out() -> Union[List, None]: + r""" Attention output projection layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_in() -> List: - """ + def mlp_in() -> Union[List, None]: + r""" h -> 4h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def mlp_out() -> List: - """ + def mlp_out() -> Union[List, None]: + r""" 4h -> h mlp layer Returns: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def embedding() -> List: - """ + def embedding() -> Union[List, None]: + r""" Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums Return: List[Layer]: List of layer object """ - return NotImplementedError + return None @staticmethod - def unembedding() -> List: - """ - Partially slice the embedding layer - vocab_size->vocab_size//gpu_nums + def unembedding() -> Union[List, None]: + r""" + Partially slice the embedding layer, None means there is no unembedding layer Return: List[Layer]: List of layer object """ - return NotImplementedError + return None ``` @@ -232,21 +233,26 @@ class CustomPolicy(Policy): - CLASS `Layer`: Parameters: - - weight (str): The weight suffix of the layer - - bias (str): The bias suffix of the layer + - suffix: (str): the suffix of the layer to indicate the attribute of the layer. - replace_layer (:class:`colosalai.nn`): The layer to replace the original layer - ignore (bool): Whether to ignore this layer if it is not in the model + - reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], but in GPT2 `Conv1D` layer is [in, out] which is reversed. + - n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices, but in multi-head attention, we need to chunk the weight with the number of $ devices * n\_head $, and each device should have a part of Q, K and V weight. - This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class. + This class is a base class used to specify the replacement policy and the suffix the layer for a particular layer. CLASS `Col_Layer(Layer)`: + - weight (str): The weight suffix of the layer + - bias (str): The bias suffix of the layer - gather_output (bool): Whether the output of this layer can be gathered, like the last layer can be gathered, but most of the time, the intermediate layers of the model do not need to be gathered. - This class inherited from `Layer`, representing the layer will be sliced along column. + This class inherited from `Layer`, representing the layer will be sliced along colum and indicate the attributes of weight and bias. Setting `bias` to `None` means ignoring bias, regardless of whether or not it originally exists. CLASS `Row_Layer(Layer)`: + - weight (str): The weight suffix of the layer + - bias (str): The bias suffix of the layer - This class inherited from `Layer`, representing the layer will be sliced along row. + This class inherited from `Layer`, representing the layer will be sliced along row. Just like `Col_Layer` but in tensor parrallel, there is no need to gather the output of layer sliced by row. - CLASS `Policy`: @@ -254,29 +260,37 @@ class CustomPolicy(Policy): - `Policy.attn_in()/attn_out()/mlp_in()/mlp_out()/embedding()/unembedding()`...... These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions. + - `Policy.argument_policy()` In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach. + - `Policy.inject_policy()` This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else. + - `Policy.binding_policy()` This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters. + - CLASS `ModelSharder(model, policy)`: This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model. + - `ModelShard.inject_model()` This function is used to inject the model to modify the forward and backward progress. + - `ModelShard.replace_layer()` This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication. + - `ModelShard.bind_layer()` This function is used to help different layers share weight or bias. + - CLASS `Slicer`: This class is used to slice tensor according to policy.