10 changes: 10 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
@@ -68,6 +68,16 @@ class PolicyLocation:
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),

# OPT
"transformers.models.opt.modeling_opt.OPTModel":
PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
"transformers.models.opt.modeling_opt.OPTForCausalLM":
PolicyLocation(file_name="opt", class_name="OPTForCausalLMPolicy"),
"transformers.models.opt.modeling_opt.OPTForSequenceClassification":
PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"),
"transformers.models.opt.modeling_opt.OPTForQuestionAnswering":
PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"),
}


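For reference, each entry above maps a fully qualified Hugging Face class name to a PolicyLocation. A minimal sketch of how such a table could be resolved to a policy instance — the mapping name `_POLICY_LIST` and the helper itself are assumptions here, not the actual autopolicy.py API:

import importlib

def lookup_policy(model):
    # e.g. "transformers.models.opt.modeling_opt.OPTForCausalLM"
    qualname = f"{type(model).__module__}.{type(model).__qualname__}"
    location = _POLICY_LIST[qualname]
    # import colossalai.shardformer.policies.<file_name> and instantiate the policy class
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()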
133 changes: 133 additions & 0 deletions colossalai/shardformer/policies/opt.py
@@ -0,0 +1,133 @@
from transformers.models.opt.modeling_opt import (
    OPTAttention,
    OPTDecoder,
    OPTDecoderLayer,
    OPTForCausalLM,
)

from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class OPTPolicy(Policy):

    def preprocess(self):
        r"""
        Resize the token embeddings so that the vocabulary size is divisible by the tensor parallel world size.
        """
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
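            # e.g. (hypothetical numbers) vocab_size = 50257 and world_size = 4:
            # new_vocab_size = 50257 + 4 - 50257 % 4 = 50260, which splits evenly across 4 ranks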
self.model.resize_token_embeddings(new_vocab_size)
return self.model

def module_policy(self):
base_policy = {
OPTDecoder:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=Embedding1D,
)
]),
OPTDecoderLayer:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="fc1",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=Linear1D_Row,
)
]),
OPTAttention:
ModulePolicyDescription(attribute_replacement={
"embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="k_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="v_proj",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="out_proj",
target_module=Linear1D_Row,
),
]),
}
if self.shard_config.fused_layernorm:
base_policy[OPTDecoder].sub_module_replacement.append(
SubModuleReplacementDescription(suffix="final_layer_norm",
target_module=FusedLayerNorm,
ignore_if_not_exist=True))
base_policy[OPTDecoderLayer].sub_module_replacement.extend([
SubModuleReplacementDescription(suffix="self_attn_layer_norm",
target_module=FusedLayerNorm,
ignore_if_not_exist=True),
SubModuleReplacementDescription(suffix="final_layer_norm",
target_module=FusedLayerNorm,
ignore_if_not_exist=True)
])
return base_policy

def new_model_class(self):
return None

def postprocess(self):
return self.model


class OPTModelPolicy(OPTPolicy):

def __init__(self) -> None:
super().__init__()


class OPTForCausalLMPolicy(OPTPolicy):

def module_policy(self):
policy = super().module_policy()
new_item = {
OPTForCausalLM:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
}

policy.update(new_item)
return policy


class OPTForSequenceClassificationPolicy(OPTPolicy):

def __init__(self) -> None:
super().__init__()


class OPTForQuestionAnsweringPolicy(OPTPolicy):

def __init__(self) -> None:
super().__init__()
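A pure-PyTorch illustration (no ColossalAI APIs) of why the policy above pairs a column-split fc1 (Linear1D_Col) with a row-split fc2 (Linear1D_Row): each rank runs its shard of the MLP without communication, and a single sum (the all-reduce in the real layer) recovers the unsharded result. The real block has an activation between fc1 and fc2, applied elementwise to the locally held columns, which does not change the argument; all shapes here are toy values:

import torch

torch.manual_seed(0)
x = torch.randn(2, 8)                      # (batch, hidden)
fc1 = torch.randn(8, 32)                   # hidden -> ffn
fc2 = torch.randn(32, 8)                   # ffn -> hidden
full = (x @ fc1) @ fc2                     # unsharded reference

partials = []
for rank in range(4):                      # tensor_parallel_size = 4
    w1 = fc1[:, rank * 8:(rank + 1) * 8]   # column shard of fc1
    w2 = fc2[rank * 8:(rank + 1) * 8, :]   # matching row shard of fc2
    partials.append((x @ w1) @ w2)         # computed entirely rank-locally
assert torch.allclose(full, sum(partials), atol=1e-4)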
57 changes: 52 additions & 5 deletions tests/kit/model_zoo/transformers/opt.py
@@ -11,14 +11,47 @@


def data_gen():
    input_ids = torch.tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]], dtype=torch.long)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.long)
return dict(input_ids=input_ids, attention_mask=attention_mask)


def data_gen_for_causal_lm():
    # causal LM data gen
    # for causal LM the labels are the input tokens themselves; since there is
    # no padding, `input_ids` can be reused directly as `labels`
    data = data_gen()
    labels = data['input_ids'].clone()
    data['labels'] = labels
    return data


def data_gen_for_sequence_classification():
    # sequence classification data gen
    # `labels` is a single class id for the whole sequence
    data = data_gen()
    data['labels'] = torch.tensor([1])
    return data


def data_gen_for_question_answering():
    # question answering data gen
    # `start_positions`/`end_positions` mark the answer span inside the input
    data = data_gen()
    data['start_positions'] = torch.tensor([0])
    data['end_positions'] = torch.tensor([1])
    return data


output_transform_fn = lambda x: x
loss_fn_for_opt_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_lm = lambda x: x.loss
config = transformers.OPTConfig(
hidden_size=128,
num_hidden_layers=2,
num_attention_heads=4,
dropout=0,
)

# register the following models
# transformers.OPTModel,
@@ -27,9 +60,23 @@ def data_gen():
model_fn=lambda: transformers.OPTModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_opt_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_causal_lm',
model_fn=lambda: transformers.OPTForCausalLM(config),
data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_lm,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_question_answering',
model_fn=lambda: transformers.OPTForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_lm,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_sequence_classification',
model_fn=lambda: transformers.OPTForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_lm,
model_attribute=ModelAttribute(has_control_flow=True))
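These entries are consumed by the tests further down; a minimal usage sketch, with the five-tuple layout taken from how the test files in this PR unpack the registry:

sub_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_zoo.items():
    model = model_fn()
    output = output_transform_fn(model(**data_gen_fn()))
    loss = loss_fn(output)    # a scalar, ready for loss.backward()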
3 changes: 1 addition & 2 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -11,10 +11,9 @@
@clear_cache_before_run()
def test_opt():
sub_registry = model_zoo.get_sub_registry('transformers_opt')

for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
model = model_fn()
        trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'start_positions', 'end_positions'])


if __name__ == '__main__':
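The new ignore_data argument presumably strips loss-related keys before symbolic tracing, so the tracer only has to follow the forward path. A sketch of the assumed filtering — the helper name and signature are hypothetical, not the actual test utility:

def drop_ignored_keys(data: dict, ignore_data: list) -> dict:
    # remove keys like 'labels' so tracing does not enter the loss computation
    return {k: v for k, v in data.items() if k not in ignore_data}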
4 changes: 1 addition & 3 deletions tests/test_shardformer/test_model/_utils.py
@@ -25,7 +25,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
# switch to train mode
original_model.train()
sharded_model.train()

# run forward
org_output = original_model(**data)
org_output = output_transform_fn(org_output)
@@ -34,5 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
shard_output = sharded_model(**data)
shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)
    return org_output, org_loss, shard_output, shard_loss
67 changes: 67 additions & 0 deletions tests/test_shardformer/test_model/test_shard_opt.py
@@ -0,0 +1,67 @@
import copy
import os

import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import (
assert_hf_output_close,
check_state_dict_equal,
clear_cache_before_run,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)

# run backward
org_loss.backward()
shard_loss.backward()

# check grad
if hasattr(org_model, 'model'):
opt_model = org_model.model
shard_opt_model = sharded_model.model
else:
opt_model = org_model
shard_opt_model = sharded_model

org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad

    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"


def check_OPTModel(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_OPTModel():
spawn(check_OPTModel, 4)
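As a sanity check on the gradient comparison in check_forward_backward: the dim=0 concatenation assumes Linear1D_Col shards the (out_features, in_features) weight along dim 0, so gathering the per-rank grads and concatenating on dim 0 should rebuild the unsharded gradient. A single-process sketch of that bookkeeping:

import torch

full_grad = torch.randn(8, 4)                # stand-in for the full q_proj weight grad
shards = torch.chunk(full_grad, 4, dim=0)    # what each of the 4 ranks would hold
assert torch.equal(torch.cat(shards, dim=0), full_grad)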