diff --git a/tests/python/driver/tvmc/conftest.py b/tests/python/driver/tvmc/conftest.py index 835b2583c725..ca4ab2247bd9 100644 --- a/tests/python/driver/tvmc/conftest.py +++ b/tests/python/driver/tvmc/conftest.py @@ -126,6 +126,23 @@ def pytorch_resnet18(tmpdir_factory): return model_file_name +@pytest.fixture(scope="session") +def pytorch_mobilenetv2_quantized(tmpdir_factory): + try: + import torch + import torchvision.models as models + except ImportError: + # Not all environments provide Pytorch, so skip if that's the case. + return "" + model = models.quantization.mobilenet_v2(quantize=True) + model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "mobilenet_v2_quantized.pth") + # Trace model into torchscript. + traced_cpu = torch.jit.trace(model, torch.randn(1, 3, 224, 224)) + torch.jit.save(traced_cpu, model_file_name) + + return model_file_name + + @pytest.fixture(scope="session") def onnx_resnet50(): base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model" diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py index a2f659b42eff..e742a1e5e4f7 100644 --- a/tests/python/driver/tvmc/test_frontends.py +++ b/tests/python/driver/tvmc/test_frontends.py @@ -217,6 +217,21 @@ def test_load_model__pth(pytorch_resnet18): assert "layer1.0.conv1.weight" in tvmc_model.params.keys() +def test_load_quantized_model__pth(pytorch_mobilenetv2_quantized): + # some CI environments wont offer torch, so skip in case it is not present + pytest.importorskip("torch") + pytest.importorskip("torchvision") + + tvmc_model = tvmc.load(pytorch_mobilenetv2_quantized, shape_dict={"input": [1, 3, 224, 224]}) + assert type(tvmc_model) is TVMCModel + assert type(tvmc_model.mod) is IRModule + assert type(tvmc_model.params) is dict + + # checking weights remain quantized and are not float32 + for p in tvmc_model.params.values(): + assert p.dtype in ["int8", "uint8", "int32"] # int32 for bias + + def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant): # some CI environments wont offer pytorch, so skip in case it is not present pytest.importorskip("torch")