From ad0ac3c1522ea5567e7e5023c4424da5d1985049 Mon Sep 17 00:00:00 2001 From: FrankLeeeee Date: Fri, 15 Jul 2022 18:11:10 +0800 Subject: [PATCH] [fx] fixed unit tests for torch 1.12 --- colossalai/fx/tracer/_tracer_utils.py | 1 - requirements/requirements-test.txt | 1 + tests/test_fx/test_pipeline/test_hf_model/test_albert.py | 1 - tests/test_fx/test_pipeline/test_hf_model/test_bert.py | 1 - tests/test_fx/test_pipeline/test_hf_model/test_gpt.py | 1 - tests/test_fx/test_pipeline/test_hf_model/test_opt.py | 1 - tests/test_fx/test_pipeline/test_hf_model/test_t5.py | 1 - tests/test_fx/test_pipeline/test_timm_model/test_timm.py | 2 -- .../test_fx/test_pipeline/test_torchvision/test_torchvision.py | 1 - tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py | 1 - tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py | 1 - tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py | 1 - tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py | 1 - tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py | 1 - tests/test_fx/test_tracer/test_patched_module.py | 2 +- tests/test_fx/test_tracer/test_timm_model/test_timm_model.py | 2 -- .../test_torchvision_model/test_torchvision_model.py | 1 - 17 files changed, 2 insertions(+), 18 deletions(-) diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py index 528c4a8e9223..c1d21e67ede0 100644 --- a/colossalai/fx/tracer/_tracer_utils.py +++ b/colossalai/fx/tracer/_tracer_utils.py @@ -22,7 +22,6 @@ def _convert(val): if isinstance(val, MetaDeviceAttribute): return 'meta' elif isinstance(val, ColoProxy): - assert val.meta_data is not None return val.meta_data return val diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 03101d69f8e0..221c82ef7b72 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,4 +1,5 @@ pytest torchvision transformers +timm titans diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py index 0bdc9a1aa694..08d20c894fe4 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_albert.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_albert.py @@ -7,7 +7,6 @@ SEQ_LENGHT = 16 -@pytest.mark.skip("error with pytorch 1.10") def test_single_sentence_albert(): MODEL_LIST = [ transformers.AlbertModel, diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py index c7af6e4d003c..a3699b6607df 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_bert.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_bert.py @@ -7,7 +7,6 @@ SEQ_LENGHT = 16 -@pytest.mark.skip("error with pytorch 1.10") def test_single_sentence_bert(): MODEL_LIST = [ transformers.BertModel, diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py index 6b982dda4f02..b973ac85444d 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_gpt.py @@ -9,7 +9,6 @@ NUM_CHUNKS = 1 -@pytest.mark.skip("error with pytorch 1.10") def test_gpt(): MODEL_LIST = [ transformers.GPT2Model, diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py index 00c16d201eaf..a55ea54feb59 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py @@ -7,7 +7,6 @@ SEQ_LENGHT = 16 -@pytest.mark.skip("error with pytorch 1.10") def test_opt(): MODEL_LIST = [ transformers.OPTModel, diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py index f24dd705cfe6..d20d188425fa 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/test_t5.py +++ b/tests/test_fx/test_pipeline/test_hf_model/test_t5.py @@ -7,7 +7,6 @@ SEQ_LENGHT = 16 -@pytest.mark.skip("error with pytorch 1.10") def test_t5(): MODEL_LIST = [ transformers.T5Model, diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py index bf11cb30a062..da3843a27550 100644 --- a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py +++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py @@ -7,7 +7,6 @@ from timm_utils import split_model_and_compare_output -@pytest.mark.skip('skip as timm is required') def test_timm_models_without_control_flow(): MODEL_LIST = [ @@ -28,7 +27,6 @@ def test_timm_models_without_control_flow(): split_model_and_compare_output(model, data) -@pytest.mark.skip('skip as timm is required') def test_timm_models_with_control_flow(): torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py index e52889e3be6c..c031210630dc 100644 --- a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py +++ b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py @@ -19,7 +19,6 @@ torch.backends.cudnn.deterministic = True -@pytest.mark.skip('skip as torchvision is required') def test_torchvision_models(): MODEL_LIST = [ tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index 2b01eabd3c6d..cf809e13a79f 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -34,7 +34,6 @@ def data_gen(): trace_model_and_compare_output(model, data_gen) -@pytest.mark.skip("error with pytorch 1.10") def test_multi_sentence_albert(): config = transformers.AlbertConfig(hidden_size=128, num_hidden_layers=2, diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index e60e4aa7c292..63ad4badc04d 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -31,7 +31,6 @@ def data_gen(): trace_model_and_compare_output(model, data_gen) -@pytest.mark.skip("error with pytorch 1.10") def test_multi_sentence_bert(): config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 9c8971a753f7..1c20e9bfd99b 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -7,7 +7,6 @@ SEQ_LENGHT = 16 -@pytest.mark.skip("error with pytorch 1.10") def test_gpt(): MODEL_LIST = [ transformers.GPT2Model, diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index 0075d1f2badf..5ac051887eb6 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -7,7 +7,6 @@ SEQ_LENGHT = 16 -@pytest.mark.skip("error with pytorch 1.10") def test_opt(): MODEL_LIST = [ transformers.OPTModel, diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 4e2614056d51..645951de978f 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -7,7 +7,6 @@ SEQ_LENGHT = 16 -@pytest.mark.skip("error with pytorch 1.10") def test_t5(): MODEL_LIST = [ transformers.T5Model, diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py index d7ceba1a5890..9b4f7c516c28 100644 --- a/tests/test_fx/test_tracer/test_patched_module.py +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -40,7 +40,7 @@ def test_embedding(): _assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape) # test group norm - gn = torch.nn.GroupNorm(4, num_channels=2) + gn = torch.nn.GroupNorm(4, num_channels=8) _assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape) # test batch norm 1d diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 7df0b2e6c9cb..5e2c40cace32 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -36,7 +36,6 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None): fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' -@pytest.mark.skip('skip as timm is required') def test_timm_models_without_control_flow(): torch.backends.cudnn.deterministic = True @@ -58,7 +57,6 @@ def test_timm_models_without_control_flow(): trace_and_compare(model_cls, tracer, data) -@pytest.mark.skip('skip as timm is required') def test_timm_models_with_control_flow(): torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index 11c3d7ea5eca..7360bd885d8d 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -8,7 +8,6 @@ from torch.fx import GraphModule -@pytest.mark.skip('skip as torchvision is required') def test_torchvision_models(): MODEL_LIST = [ tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,