Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Empty file.
63 changes: 63 additions & 0 deletions colossalai/shardformer/model/modeling_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import Any, Dict, List, Type


from transformers import BertForMaskedLM
from transformers.models.bert.modeling_bert import MaskedLMOutput
class BertForMaskedLM_(BertForMaskedLM):
    """BertForMaskedLM with a forward method intended to be injected by shardformer.

    Behaves like ``transformers.BertForMaskedLM.forward`` but is defined here so
    the sharding machinery can substitute it for the original implementation.
    """

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Run the masked-LM forward pass.

        Args:
            input_ids: Token ids fed to the underlying BERT encoder.
            labels: Optional masked-LM targets; positions with value ``-100``
                are ignored by the loss (padding convention of CrossEntropyLoss).
            return_dict: If falsy, return a tuple instead of ``MaskedLMOutput``.
            **kwargs: Ignored extra arguments, accepted for signature
                compatibility with callers that pass additional keywords.

        Returns:
            ``MaskedLMOutput`` when ``return_dict`` is truthy, otherwise a tuple
            ``(loss?, prediction_scores, *encoder_extras)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # Tuple output mirrors the HF convention: loss first when present.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Empty file.
41 changes: 41 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch.nn as nn

def build_policies():
    """Build the mapping from supported transformers model classes to policies.

    Imports are performed lazily inside the function so that merely importing
    this module does not pull in every supported model class.

    Return:
        The dict for the policies
    """
    from transformers.models.bert.modeling_bert import (
        BertForMaskedLM,
        BertForSequenceClassification,
    )

    from .bert import BertForMaskedLMPolicy, BertForSequenceClassificationPolicy

    return {
        BertForMaskedLM: BertForMaskedLMPolicy,
        BertForSequenceClassification: BertForSequenceClassificationPolicy,
    }

def get_autopolicy(model: nn.Module):
    """Return the auto policy registered for the given model instance.

    Args:
        model: The model to be used

    Return:
        The auto policy for the model

    Raises:
        NotImplementedError: If no policy is registered for the model's class.
    """
    auto_policy_dict = build_policies()
    model_cls = model.__class__
    if model_cls not in auto_policy_dict:
        raise NotImplementedError(f"Auto policy for {model_cls.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}")
    return auto_policy_dict[model_cls]

# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
# print(policy)
182 changes: 182 additions & 0 deletions colossalai/shardformer/policies/basepolicy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# part of code modified from https://github.com/tunib-ai/parallelformers

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch
import torch.nn as nn
from transformers import AutoConfig

import colossalai.nn as col_nn

@dataclass
class Argument:
    """
    Settings passed to a policy for one original layer class.

    Args:
        attr_dict: Attribute name -> value overrides to apply to the layer
        param_funcs: Policy static methods describing how to shard the layer's parameters
        binding_layers: Layer classes whose parameters are tied to this one (defaults to empty)
    """
    attr_dict : Dict[str, Any]
    param_funcs : List[Callable]
    binding_layers : List[nn.Module] = field(default_factory=list)

@dataclass
class Layer:
    """
    The layer object for the policy

    Args:
        weight: The weight name of the layer (None if the layer has no weight)
        bias: The bias name of the layer (None if the layer has no bias)
        replace_layer: The layer to replace the original layer
        ignore: Whether to ignore this layer if it is not in the model
    """
    # Annotations made Optional[...] to match the None defaults below.
    weight: Optional[str] = None
    bias: Optional[str] = None
    replace_layer: Optional[Any] = None
    ignore: bool = False


@dataclass
class Col_Layer(Layer):
    """
    Class for col shard layer in MegatronLM

    Args:
        gather_output: Whether to all-gather the sharded output back to the full size
    """
    gather_output: bool = False


@dataclass
class Row_Layer(Layer):
    """
    Class for row shard layer in MegatronLM
    """
    pass


class Policy():
    """
    The base class for all the policies.

    For each different model, it should have a different policy class, like BertPolicy for Bert Model
    or OPTPolicy for OPT model.

    AutoPolicy:
        shardformer already defined some policies for huggingface model, just set custom_policy = None
        to use the auto policy. In shardformer autopolicy, we define a base policy for one type of model,
        like BertPolicy, and for each different Bert model in huggingface like BertForMaskedLM,
        BertForSequenceClassification, etc., we define a different policy class
        and overwrite the method inject_policy.

    CustomPolicy:
    """

    @staticmethod
    def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, "Argument"]:
        """
        Return a dict, the key is layer will be modified and the value is the Argument class with param setting and param functions

        Args:
            model_config: The config of transformer model
            shard_setting: The config of distributed model

        Return:
            Dict for the modify policy,
            {
                origin layer class1 (nn.Module): Argument(
                    attr_dict = {
                        argument1: value1,
                        argument2: value2,
                        ...
                    },
                    param_funcs = [
                        staticmethod1,
                        staticmethod2,
                        ...
                    ]
                ),
                origin layer class2 (nn.Module): Argument(
                    attr_dict = {
                        argument1: value1,
                        argument2: value2,
                        ...
                    },
                    param_funcs = [
                        staticmethod1,
                        staticmethod2,
                        ...
                    ]
                ),
                ...
            }

        """
        raise NotImplementedError

    @staticmethod
    def inject_policy() -> Tuple[nn.Module, nn.Module]:
        """
        Return the dict for the inject model

        Return:
            The injected model, key is the original model and value is the new shardmodel
        """
        # Base class injects nothing; subclasses override to return (orig, shard).
        return ()

    @staticmethod
    def attn_in() -> List:
        """
        Attention qkv layer

        Returns:
            List[Layer]: List of layer object, each layer is the new
        """
        # Fixed: was `return NotImplementedError` (returned the exception class).
        raise NotImplementedError

    @staticmethod
    def attn_out() -> List:
        """
        Attention output projection layer

        Returns:
            List[Layer]: List of layer object
        """
        raise NotImplementedError

    @staticmethod
    def mlp_in() -> List:
        """
        h -> 4h mlp layer

        Returns:
            List[Layer]: List of layer object
        """
        raise NotImplementedError

    @staticmethod
    def mlp_out() -> List:
        """
        4h -> h mlp layer

        Returns:
            List[Layer]: List of layer object
        """
        raise NotImplementedError

    @staticmethod
    def embedding() -> List:
        """
        Partially slice the embedding layer
        vocab_size->vocab_size//gpu_nums

        Return:
            List[Layer]: List of layer object
        """
        raise NotImplementedError

    @staticmethod
    def unembedding() -> List:
        """
        Partially slice the embedding layer
        vocab_size->vocab_size//gpu_nums

        Return:
            List[Layer]: List of layer object
        """
        raise NotImplementedError
Loading