From c187562fa3283bfdd25a16db03c6bc34787851f5 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 27 Jun 2023 11:23:10 +0800 Subject: [PATCH 1/4] [shardformer] shardformer support opt models --- colossalai/shardformer/policies/autopolicy.py | 8 + colossalai/shardformer/policies/opt.py | 146 ++++++++++++++++++ tests/kit/model_zoo/transformers/opt.py | 19 ++- tests/test_shardformer/test_model/_utils.py | 4 +- .../test_model/test_shard_opt.py | 60 +++++++ 5 files changed, 230 insertions(+), 7 deletions(-) create mode 100644 colossalai/shardformer/policies/opt.py create mode 100644 tests/test_shardformer/test_model/test_shard_opt.py diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index b1b8c6156f9f..b8d2e5627fab 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -68,6 +68,14 @@ class PolicyLocation: PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"), "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"), + + # OPT + "transformers.models.opt.modeling_opt.OPTModel": + PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), + "transformers.models.opt.modeling_opt.OPTForCausalLM": + PolicyLocation(file_name="opt", class_name="OPTForCausalLMPolicy"), + "transformers.models.opt.modeling_opt.OPTForSequenceClassification": + PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"), } diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py new file mode 100644 index 000000000000..fcad60a5f32c --- /dev/null +++ b/colossalai/shardformer/policies/opt.py @@ -0,0 +1,146 @@ +from transformers.models.opt.modeling_opt import ( + OPTDecoder, + OPTDecoderLayer, + OPTAttention, + OPTForCausalLM, + OPTForSequenceClassification +) +from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, FusedLayerNorm, Embedding1D +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +class OPTPolicy(Policy): + + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): + base_policy = { + OPTDecoder: + ModulePolicyDescription( + attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=Embedding1D, + ) + ]), + OPTDecoderLayer: + ModulePolicyDescription( + attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]), + OPTAttention: + ModulePolicyDescription( + attribute_replacement={ + "embed_dim": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]), + } + if self.shard_config.fused_layernorm: + base_policy[OPTDecoder].sub_module_replacement.append( + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True + )) + base_policy[OPTDecoderLayer].sub_module_replacement.extend([ + SubModuleReplacementDescription( + suffix="self_attn_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True + ), + SubModuleReplacementDescription( + suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True + )]) + return base_policy + + def new_model_class(self): + return None + + def postprocess(self): + return self.model + +class OPTModelPolicy(OPTPolicy): + def __init__(self) -> None: + super().__init__() + +class OPTForCausalLMPolicy(OPTPolicy): + def module_policy(self): + policy = super().module_policy() + new_item = { + OPTForCausalLM: + ModulePolicyDescription( + attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) + } + + policy.update(new_item) + return policy + +class OPTForSequenceClassificationPolicy(OPTPolicy): + + def module_policy(self): + policy = super().module_policy() + new_item = { + OPTForSequenceClassification: + ModulePolicyDescription( + attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="score", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) + } + + policy.update(new_item) + return policy \ No newline at end of file diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index d9c4a0b3c23c..79dc457601f3 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -15,10 +15,19 @@ def data_gen(): attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) return dict(input_ids=input_ids, attention_mask=attention_mask) +def data_gen_for_causal_lm(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + labels = data['input_ids'].clone() + data['labels'] = labels + return data output_transform_fn = lambda x: x - -config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4) +loss_fn_for_opt_model = lambda x: x.last_hidden_state.mean() +loss_fn_for_causal_lm = lambda x: x.loss +config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, + dropout=0,) # register the following models # transformers.OPTModel, @@ -27,9 +36,11 @@ def data_gen(): model_fn=lambda: transformers.OPTModel(config), data_gen_fn=data_gen, output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_opt_model, model_attribute=ModelAttribute(has_control_flow=True)) model_zoo.register(name='transformers_opt_for_causal_lm', model_fn=lambda: transformers.OPTForCausalLM(config), - data_gen_fn=data_gen, + data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, - model_attribute=ModelAttribute(has_control_flow=True)) + loss_fn=loss_fn_for_causal_lm, + model_attribute=ModelAttribute(has_control_flow=True)) \ No newline at end of file diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index a282e0bb919e..ad7c408aeb38 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -25,7 +25,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, # switch to train mode original_model.train() sharded_model.train() - # run forward org_output = original_model(**data) org_output = output_transform_fn(org_output) @@ -34,5 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, shard_output = sharded_model(**data) shard_output = output_transform_fn(shard_output) shard_loss = loss_fn(shard_output) - - return org_output, org_loss, shard_output, shard_loss + return org_output, org_loss, shard_output, shard_loss \ No newline at end of file diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py new file mode 100644 index 000000000000..2d9b1fc67aa6 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -0,0 +1,60 @@ +import os + +import pytest +import torch +import copy + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn, check_state_dict_equal +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' + + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4) + + # run backward + org_loss.backward() + shard_loss.backward() + + # check grad + if hasattr(org_model, 'model'): + opt_model = org_model.model + shard_opt_model = sharded_model.model + else: + opt_model = org_model + shard_opt_model = sharded_model + + org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad + + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +def check_OPTModel(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + sub_model_zoo = model_zoo.get_sub_registry('transformers_opt') + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): + org_model, sharded_model = build_model(world_size, model_fn) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_OPTModel(): + spawn(check_OPTModel, 4) \ No newline at end of file From 684cb96c22990d43506fd47144759e99d4ba9fe5 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 27 Jun 2023 12:53:37 +0800 Subject: [PATCH 2/4] [shardformer] shardformer support opt models, fix --- tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py | 1 + tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index 58c8132e1490..0c4aad30254a 100644 --- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -22,6 +22,7 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non try: meta_args = {k: v.to('meta') for k, v in inputs.items()} gm = symbolic_trace(model, meta_args=meta_args) + # gm = None except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index f528db6a64ef..ead535d06f51 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -11,10 +11,9 @@ @clear_cache_before_run() def test_opt(): sub_registry = model_zoo.get_sub_registry('transformers_opt') - for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) if __name__ == '__main__': From e3036a60e3b26bfbe664af2435de20ca6e5696ed Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 27 Jun 2023 12:55:17 +0800 Subject: [PATCH 3/4] [shardformer] shardformer support opt models, fix --- tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index 0c4aad30254a..58c8132e1490 100644 --- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -22,7 +22,6 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non try: meta_args = {k: v.to('meta') for k, v in inputs.items()} gm = symbolic_trace(model, meta_args=meta_args) - # gm = None except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") From af122c5c440f25c847178c046728fb6397bac043 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 27 Jun 2023 16:59:18 +0800 Subject: [PATCH 4/4] [shardformer] shardformer support opt models, fix --- colossalai/shardformer/policies/autopolicy.py | 2 + colossalai/shardformer/policies/opt.py | 173 ++++++++---------- tests/kit/model_zoo/transformers/opt.py | 50 ++++- .../test_tracer/test_hf_model/test_hf_opt.py | 2 +- .../test_model/test_shard_opt.py | 13 +- 5 files changed, 136 insertions(+), 104 deletions(-) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index b8d2e5627fab..9cc583d58b11 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -76,6 +76,8 @@ class PolicyLocation: PolicyLocation(file_name="opt", class_name="OPTForCausalLMPolicy"), "transformers.models.opt.modeling_opt.OPTForSequenceClassification": PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"), + "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": + PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"), } diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index fcad60a5f32c..f467726e5580 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -1,13 +1,16 @@ from transformers.models.opt.modeling_opt import ( + OPTAttention, OPTDecoder, OPTDecoderLayer, - OPTAttention, OPTForCausalLM, - OPTForSequenceClassification + OPTForSequenceClassification, ) -from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, FusedLayerNorm, Embedding1D + +from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row + from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + class OPTPolicy(Policy): def preprocess(self): @@ -25,75 +28,65 @@ def preprocess(self): def module_policy(self): base_policy = { OPTDecoder: - ModulePolicyDescription( - attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="embed_tokens", - target_module=Embedding1D, - ) - ]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="embed_tokens", + target_module=Embedding1D, + ) + ]), OPTDecoderLayer: - ModulePolicyDescription( - attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="fc1", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="fc2", - target_module=Linear1D_Row, - ) - ]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="fc1", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="fc2", + target_module=Linear1D_Row, + ) + ]), OPTAttention: - ModulePolicyDescription( - attribute_replacement={ - "embed_dim": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "num_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size - }, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="q_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="k_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="v_proj", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="out_proj", - target_module=Linear1D_Row, - ), - ]), + ModulePolicyDescription(attribute_replacement={ + "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v_proj", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="out_proj", + target_module=Linear1D_Row, + ), + ]), } if self.shard_config.fused_layernorm: base_policy[OPTDecoder].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True - )) + SubModuleReplacementDescription(suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True)) base_policy[OPTDecoderLayer].sub_module_replacement.extend([ - SubModuleReplacementDescription( - suffix="self_attn_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True - ), - SubModuleReplacementDescription( - suffix="final_layer_norm", - target_module=FusedLayerNorm, - ignore_if_not_exist=True - )]) + SubModuleReplacementDescription(suffix="self_attn_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True), + SubModuleReplacementDescription(suffix="final_layer_norm", + target_module=FusedLayerNorm, + ignore_if_not_exist=True) + ]) return base_policy def new_model_class(self): @@ -101,46 +94,40 @@ def new_model_class(self): def postprocess(self): return self.model - + + class OPTModelPolicy(OPTPolicy): + def __init__(self) -> None: super().__init__() + class OPTForCausalLMPolicy(OPTPolicy): + def module_policy(self): policy = super().module_policy() new_item = { OPTForCausalLM: - ModulePolicyDescription( - attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)) - ]) + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) } policy.update(new_item) return policy - + + class OPTForSequenceClassificationPolicy(OPTPolicy): - def module_policy(self): - policy = super().module_policy() - new_item = { - OPTForSequenceClassification: - ModulePolicyDescription( - attribute_replacement={}, - param_replacement=[], - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="score", - target_module=Linear1D_Col, - kwargs=dict(gather_output=True)) - ]) - } + def __init__(self) -> None: + super().__init__() - policy.update(new_item) - return policy \ No newline at end of file + +class OPTForQuestionAnsweringPolicy(OPTPolicy): + + def __init__(self) -> None: + super().__init__() diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index 79dc457601f3..4463ae12b901 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -11,10 +11,11 @@ def data_gen(): - input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) - attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) + input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) + def data_gen_for_causal_lm(): # LM data gen # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` @@ -23,11 +24,34 @@ def data_gen_for_causal_lm(): data['labels'] = labels return data + +def data_gen_for_sequence_classification(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + labels = data['input_ids'].clone() + data['labels'] = torch.tensor([1]) + return data + + +def data_gen_for_question_answering(): + # LM data gen + # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels` + data = data_gen() + data['start_positions'] = torch.tensor([0]) + data['end_positions'] = torch.tensor([1]) + return data + + output_transform_fn = lambda x: x loss_fn_for_opt_model = lambda x: x.last_hidden_state.mean() -loss_fn_for_causal_lm = lambda x: x.loss -config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, - dropout=0,) +loss_fn_for_lm = lambda x: x.loss +config = transformers.OPTConfig( + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=4, + dropout=0, +) # register the following models # transformers.OPTModel, @@ -42,5 +66,17 @@ def data_gen_for_causal_lm(): model_fn=lambda: transformers.OPTForCausalLM(config), data_gen_fn=data_gen_for_causal_lm, output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_causal_lm, - model_attribute=ModelAttribute(has_control_flow=True)) \ No newline at end of file + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_opt_for_question_answering', + model_fn=lambda: transformers.OPTForQuestionAnswering(config), + data_gen_fn=data_gen_for_question_answering, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True)) +model_zoo.register(name='transformers_opt_for_sequence_classification', + model_fn=lambda: transformers.OPTForSequenceClassification(config), + data_gen_fn=data_gen_for_sequence_classification, + output_transform_fn=output_transform_fn, + loss_fn=loss_fn_for_lm, + model_attribute=ModelAttribute(has_control_flow=True)) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index ead535d06f51..c68b89e82fbe 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -13,7 +13,7 @@ def test_opt(): sub_registry = model_zoo.get_sub_registry('transformers_opt') for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() - trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) + trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'start_positions', 'end_positions']) if __name__ == '__main__': diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py index 2d9b1fc67aa6..4d4c55770144 100644 --- a/tests/test_shardformer/test_model/test_shard_opt.py +++ b/tests/test_shardformer/test_model/test_shard_opt.py @@ -1,12 +1,18 @@ +import copy import os import pytest import torch -import copy import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn, check_state_dict_equal +from colossalai.testing import ( + assert_hf_output_close, + check_state_dict_equal, + clear_cache_before_run, + rerun_if_address_is_in_use, + spawn, +) from tests.kit.model_zoo import model_zoo from tests.test_shardformer.test_model._utils import build_model, run_forward @@ -53,8 +59,9 @@ def check_OPTModel(rank, world_size, port): torch.cuda.empty_cache() + @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_OPTModel(): - spawn(check_OPTModel, 4) \ No newline at end of file + spawn(check_OPTModel, 4)