20 changes: 15 additions & 5 deletions colossalai/shardformer/policies/autopolicy.py
@@ -25,17 +25,19 @@ class PolicyLocation:
_POLICY_LIST = {
# BERT
"transformers.models.bert.modeling_bert.BertModel":
PolicyLocation(file_name="bert", class_name="BertPolicy"),
PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
"transformers.models.bert.modeling_bert.BertForPreTraining":
PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"),
"transformers.models.bert.modeling_bert.BertForMaskedLM":
PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
"transformers.models.bert.modeling_bert.BertLMHeadModel":
PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
"transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
"transformers.models.bert.modeling_bert.BertForMaskedLM":
PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
"transformers.models.bert.modeling_bert.BertForSequenceClassification":
PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"),
"transformers.models.bert.modeling_bert.BertForTokenClassification":
PolicyLocation(file_name="bert", class_name="BertForTokenClassificationPolicy"),
"transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
"transformers.models.bert.modeling_bert.BertForMultipleChoice":
PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),

@@ -58,6 +60,14 @@ class PolicyLocation:
# GPT2
"transformers.models.gpt2.modeling_gpt2.GPT2Model":
PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel":
PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),
}
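
For context, `_POLICY_LIST` keys the fully qualified class name of each supported model to a `PolicyLocation` (a policy file name plus a policy class name). Below is a minimal sketch of how such a registry can be resolved at runtime; the helper name `resolve_policy` is illustrative, not necessarily the library's actual API.

import importlib

import torch.nn as nn


def resolve_policy(model: nn.Module):
    # Illustrative only: build the model's fully qualified class name,
    # look it up in _POLICY_LIST, then import and instantiate the matching
    # policy class from colossalai.shardformer.policies.<file_name>.
    qualified_name = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = _POLICY_LIST[qualified_name]
    policy_file = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(policy_file, location.class_name)()

With the entries added here, a `GPT2LMHeadModel` instance, for example, would resolve to `GPT2LMHeadModelPolicy` defined in `gpt2.py`.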


23 changes: 15 additions & 8 deletions colossalai/shardformer/policies/bert.py
@@ -131,8 +131,8 @@ def postprocess(self):
return self.model


# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()
@@ -162,8 +162,8 @@ def postprocess(self):
return self.model


# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()
@@ -193,15 +193,22 @@ def postprocess(self):
return self.model


# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()


# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()


# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):

def __init__(self) -> None:
super().__init__()
81 changes: 79 additions & 2 deletions colossalai/shardformer/policies/gpt2.py
@@ -1,7 +1,9 @@
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


@@ -82,7 +84,6 @@ def module_policy(self):
}

def new_model_class(self):

return self.model

def postprocess(self):
@@ -94,3 +95,79 @@ class GPT2ModelPolicy(GPT2Policy):

def __init__(self) -> None:
super().__init__()


# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
module_policy = super().module_policy()
addon_module = {
GPT2LMHeadModel:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy

def postprocess(self):
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model


# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):

def __init__(self) -> None:
super().__init__()

def module_policy(self):
module_policy = super().module_policy()
addon_module = {
GPT2DoubleHeadsModel:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy

def postprocess(self):
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model


# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):

def __init__(self) -> None:
super().__init__()


# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):

def __init__(self) -> None:
super().__init__()
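
GPT2LMHeadModelPolicy and GPT2DoubleHeadsModelPolicy above share one pattern: `module_policy()` swaps `lm_head` for `col_nn.Linear1D_Col` with `gather_output=True`, so the full-vocabulary logits are gathered after the column-parallel matmul, and `postprocess()` re-binds `transformer.wte.weight` and `lm_head.weight` to a single `nn.Parameter`, restoring GPT-2's embedding/LM-head weight tying that the module replacement breaks. A small illustrative check of that tying, assuming `model` is a sharded `GPT2LMHeadModel` (this is not part of the policy itself):

def lm_head_is_tied(model) -> bool:
    # After postprocess(), both names should reference the same Parameter
    # object, so an optimizer step on one also updates the other.
    return model.lm_head.weight is model.transformer.wte.weight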
140 changes: 102 additions & 38 deletions tests/kit/model_zoo/transformers/bert.py
@@ -6,83 +6,147 @@
# ===============================
# Register single-sentence BERT
# ===============================
BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen_fn():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
# define data gen function
def data_gen():
# Generated from the following code snippet:
#
# from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# input = 'Hello, my dog is cute'
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
# token_type_ids = tokenized_input['token_type_ids']
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


def data_gen_for_lm():
# LM data gen
# for LM training, `labels` are the target token ids; since there is no padding here, `input_ids` is reused as `labels`
data = data_gen()
data['labels'] = data['input_ids'].clone()
return data


def data_gen_for_pretraining():
# pretraining data gen
# `next_sentence_label` is the label for next sentence prediction, 0 or 1
data = data_gen_for_lm()
data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64)
return data


def data_gen_for_sequence_classification():
# sequence classification data gen
# `labels` is the label for sequence classification, 0 or 1
data = data_gen()
data['labels'] = torch.tensor([1], dtype=torch.int64)
return data


def data_gen_for_token_classification():
# token classification data gen
# for token classification, `labels` holds a per-token class id (0 or 1), not a token id
data = data_gen()
data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data


def data_gen_for_mcq():
# multiple choice question data gen
# Generated from the following code snippet:
#
# tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
# prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
# choice0 = "It is eaten with a fork and a knife."
# choice1 = "It is eaten while held in the hand."
# encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
# data = {k: v.unsqueeze(0) for k, v in encoding.items()}
# data['labels'] = torch.tensor([0], dtype=torch.int64)
input_ids = torch.tensor([[[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102
],
[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
2218, 1999, 1996, 2192, 1012, 102, 0
]]])
token_type_ids = torch.tensor(
[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
attention_mask = torch.tensor(
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
labels = torch.tensor([0], dtype=torch.int64)

return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
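

# Editor's note (illustrative sketch, not part of this PR): the hard-coded MCQ
# sample above has batch size 1, two answer choices, and 35 tokens per choice
# (the second choice ends in one trailing pad token). A quick shape check:
def _check_mcq_shapes():
    sample = data_gen_for_mcq()
    assert sample['input_ids'].shape == (1, 2, 35)
    assert sample['token_type_ids'].shape == (1, 2, 35)
    assert sample['attention_mask'].shape == (1, 2, 35)
    assert sample['labels'].shape == (1,)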


# define output transform function
output_transform_fn = lambda x: x

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
# define loss function
loss_fn_for_bert_model = lambda x: x.pooler_output.mean()
loss_fn = lambda x: x.loss

config = transformers.BertConfig(hidden_size=128,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=256,
hidden_dropout_prob=0,
attention_probs_dropout_prob=0)

# register the BERT variants
model_zoo.register(name='transformers_bert',
model_fn=lambda: transformers.BertModel(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_bert_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_pretraining',
model_fn=lambda: transformers.BertForPreTraining(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen_for_pretraining,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_lm_head_model',
model_fn=lambda: transformers.BertLMHeadModel(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_masked_lm',
model_fn=lambda: transformers.BertForMaskedLM(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_sequence_classification',
model_fn=lambda: transformers.BertForSequenceClassification(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_token_classification',
model_fn=lambda: transformers.BertForTokenClassification(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))


# ===============================
# Register multi-sentence BERT
# ===============================
def data_gen_for_next_sentence():
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
return encoding


def data_gen_for_mcq():
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
choice0 = "It is eaten with a fork and a knife."
choice1 = "It is eaten while held in the hand."
encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
return encoding


# register the following models
model_zoo.register(name='transformers_bert_for_next_sentence',
model_fn=lambda: transformers.BertForNextSentencePrediction(config),
data_gen_fn=data_gen_for_next_sentence,
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_mcq',
model_fn=lambda: transformers.BertForMultipleChoice(config),
data_gen_fn=data_gen_for_mcq,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
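
Each registration pairs a model constructor with a matching data generator and loss function. A minimal sketch of the intended contract follows; it assumes nothing about the test suite's actual harness beyond the arguments passed to `model_zoo.register` above.

def run_entry(model_fn, data_gen_fn, output_transform_fn, loss_fn):
    # Build the model, generate one sample, run a forward pass, transform the
    # output, reduce it to a scalar loss, and backpropagate.
    model = model_fn()
    data = data_gen_fn()
    outputs = output_transform_fn(model(**data))
    loss = loss_fn(outputs)
    loss.backward()
    return loss

# e.g. for the plain BERT entry registered above:
# run_entry(lambda: transformers.BertModel(config), data_gen,
#           output_transform_fn, loss_fn_for_bert_model)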