Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion colossalai/fx/tracer/_tracer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def _convert(val):
if isinstance(val, MetaDeviceAttribute):
return 'meta'
elif isinstance(val, ColoProxy):
assert val.meta_data is not None
return val.meta_data
return val

Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pytest
torchvision
transformers
timm
titans
1 change: 0 additions & 1 deletion tests/test_fx/test_pipeline/test_hf_model/test_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
SEQ_LENGHT = 16


@pytest.mark.skip("error with pytorch 1.10")
def test_single_sentence_albert():
MODEL_LIST = [
transformers.AlbertModel,
Expand Down
1 change: 0 additions & 1 deletion tests/test_fx/test_pipeline/test_hf_model/test_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
SEQ_LENGHT = 16


@pytest.mark.skip("error with pytorch 1.10")
def test_single_sentence_bert():
MODEL_LIST = [
transformers.BertModel,
Expand Down
1 change: 0 additions & 1 deletion tests/test_fx/test_pipeline/test_hf_model/test_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
NUM_CHUNKS = 1


@pytest.mark.skip("error with pytorch 1.10")
def test_gpt():
MODEL_LIST = [
transformers.GPT2Model,
Expand Down
1 change: 0 additions & 1 deletion tests/test_fx/test_pipeline/test_hf_model/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
SEQ_LENGHT = 16


@pytest.mark.skip("error with pytorch 1.10")
def test_opt():
MODEL_LIST = [
transformers.OPTModel,
Expand Down
1 change: 0 additions & 1 deletion tests/test_fx/test_pipeline/test_hf_model/test_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
SEQ_LENGHT = 16


@pytest.mark.skip("error with pytorch 1.10")
def test_t5():
MODEL_LIST = [
transformers.T5Model,
Expand Down
2 changes: 0 additions & 2 deletions tests/test_fx/test_pipeline/test_timm_model/test_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from timm_utils import split_model_and_compare_output


@pytest.mark.skip('skip as timm is required')
def test_timm_models_without_control_flow():

MODEL_LIST = [
Expand All @@ -28,7 +27,6 @@ def test_timm_models_without_control_flow():
split_model_and_compare_output(model, data)


@pytest.mark.skip('skip as timm is required')
def test_timm_models_with_control_flow():
torch.backends.cudnn.deterministic = True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
torch.backends.cudnn.deterministic = True


@pytest.mark.skip('skip as torchvision is required')
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def data_gen():
trace_model_and_compare_output(model, data_gen)


@pytest.mark.skip("error with pytorch 1.10")
def test_multi_sentence_albert():
config = transformers.AlbertConfig(hidden_size=128,
num_hidden_layers=2,
Expand Down
1 change: 0 additions & 1 deletion tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def data_gen():
trace_model_and_compare_output(model, data_gen)


@pytest.mark.skip("error with pytorch 1.10")
def test_multi_sentence_bert():
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
Expand Down
1 change: 0 additions & 1 deletion tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
SEQ_LENGHT = 16


@pytest.mark.skip("error with pytorch 1.10")
def test_gpt():
MODEL_LIST = [
transformers.GPT2Model,
Expand Down
1 change: 0 additions & 1 deletion tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
SEQ_LENGHT = 16


@pytest.mark.skip("error with pytorch 1.10")
def test_opt():
MODEL_LIST = [
transformers.OPTModel,
Expand Down
1 change: 0 additions & 1 deletion tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
SEQ_LENGHT = 16


@pytest.mark.skip("error with pytorch 1.10")
def test_t5():
MODEL_LIST = [
transformers.T5Model,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fx/test_tracer/test_patched_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_embedding():
_assert_output_shape(data, ln, patched_module.torch_nn_normalize, False, data.shape)

# test group norm
gn = torch.nn.GroupNorm(4, num_channels=2)
gn = torch.nn.GroupNorm(4, num_channels=8)
_assert_output_shape(data, gn, patched_module.torch_nn_normalize, False, data.shape)

# test batch norm 1d
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def trace_and_compare(model_cls, tracer, data, meta_args=None):
fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'


@pytest.mark.skip('skip as timm is required')
def test_timm_models_without_control_flow():
torch.backends.cudnn.deterministic = True

Expand All @@ -58,7 +57,6 @@ def test_timm_models_without_control_flow():
trace_and_compare(model_cls, tracer, data)


@pytest.mark.skip('skip as timm is required')
def test_timm_models_with_control_flow():
torch.backends.cudnn.deterministic = True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch.fx import GraphModule


@pytest.mark.skip('skip as torchvision is required')
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
Expand Down