
[Bug] [Relax] Cannot import mobilenet_v3 #17068

@mshr-h

Description


mobilenet_v3 cannot be imported because the Relax PyTorch FX frontend (from_fx) does not support the Hardswish and Hardsigmoid modules. I'll try to fix it.
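For reference, PyTorch defines both activations piecewise: Hardsigmoid(x) is 0 for x <= -3, 1 for x >= 3, and x/6 + 1/2 otherwise, and Hardswish(x) = x * Hardsigmoid(x). Both can therefore be written with elementwise primitives such as add, clip, divide and multiply, i.e. Hardsigmoid(x) = clip(x + 3, 0, 6) / 6. The sketch below checks that decomposition on scalars in plain Python; the function names are hypothetical helpers, not TVM or PyTorch APIs.

```python
def clip(x: float, lo: float, hi: float) -> float:
    """Clamp x into the closed interval [lo, hi]."""
    return max(lo, min(x, hi))


def hardsigmoid(x: float) -> float:
    """Hardsigmoid(x) = clip(x + 3, 0, 6) / 6 (hypothetical scalar reference)."""
    return clip(x + 3.0, 0.0, 6.0) / 6.0


def hardswish(x: float) -> float:
    """Hardswish(x) = x * Hardsigmoid(x) (hypothetical scalar reference)."""
    return x * hardsigmoid(x)


if __name__ == "__main__":
    # Sample points covering all three pieces of the definition.
    for v in (-4.0, -3.0, 0.0, 1.5, 3.0, 5.0):
        print(v, hardsigmoid(v), hardswish(v))
```

A converter for either module could emit the same add/clip/divide/multiply sequence instead of needing a dedicated operator, although whether the frontend fix takes that route is an assumption here.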

TODOs

Expected behavior

Both mobilenet_v3_small and mobilenet_v3_large can be imported with from_fx.

Actual behavior

Running the repro script from Steps to reproduce fails with the following error:

$ python compile_mobilenet_v3.py 
Traceback (most recent call last):
  File "/home/ubuntu/data/sandbox/tvm_/relax_/mobilenet_v3/compile_mobilenet_v3.py", line 34, in <module>
    main()
  File "/home/ubuntu/data/sandbox/tvm_/relax_/mobilenet_v3/compile_mobilenet_v3.py", line 21, in main
    mod = from_fx(graph_model, [(inp.shape, "float32")])
  File "/home/ubuntu/data/sandbox/.dep/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1698, in from_fx
    return TorchFXImporter().from_fx(
  File "/home/ubuntu/data/sandbox/.dep/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1570, in from_fx
    type(module) in self.convert_map
AssertionError: Unsupported module type <class 'torch.nn.modules.activation.Hardswish'>
[20:07:07] /home/ubuntu/data/sandbox/.dep/tvm/src/relax/ir/block_builder.cc:66: Warning: BlockBuilder destroyed with remaining blocks!
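The assertion in the traceback comes from a lookup of type(module) in the importer's convert_map, so any nn.Module subclass without an entry fails before translation starts. The mock below is a hypothetical, simplified illustration of that type-keyed dispatch pattern, not TVM's actual code: MiniImporter and the stand-in ReLU/Hardswish classes are invented so the example runs without torch or tvm.

```python
from typing import Callable, Dict


class ReLU:  # hypothetical stand-in for torch.nn.ReLU
    pass


class Hardswish:  # hypothetical stand-in for torch.nn.Hardswish
    pass


class MiniImporter:
    """Hypothetical importer with a convert_map keyed by module class."""

    def __init__(self) -> None:
        # Hardswish is deliberately missing, reproducing the failure mode.
        self.convert_map: Dict[type, Callable[[float], float]] = {
            ReLU: lambda x: max(x, 0.0),
        }

    def convert(self, module: object, x: float) -> float:
        # Same style of check as in the traceback: unknown types fail fast.
        assert type(module) in self.convert_map, f"Unsupported module type {type(module)}"
        return self.convert_map[type(module)](x)


importer = MiniImporter()
print(importer.convert(ReLU(), -2.0))
try:
    importer.convert(Hardswish(), 1.0)
except AssertionError as err:
    print(err)

# Adding an entry per missing module type is the general shape of a fix.
importer.convert_map[Hardswish] = lambda x: x * max(0.0, min(x + 3.0, 6.0)) / 6.0
print(importer.convert(Hardswish(), 1.5))
```

In this sketch, the first call returns 0.0, the second raises the "Unsupported module type" assertion, and after registration the Hardswish call returns 1.125.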

Environment

OS: Ubuntu 22.04 LTS on WSL2
TVM: 0e622e1
PyTorch: 2.3.0
Torchvision: 0.18.0

Steps to reproduce

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx
import torch
import torchvision


def main():
  model_name = "mobilenet_v3_small"  # mobilenet_v3_small or mobilenet_v3_large
  inp = torch.rand(8, 3, 224, 224)

  weights = torchvision.models.get_model_weights(model_name).DEFAULT
  model_pth = torchvision.models.get_model(model_name, weights=weights).eval()

  # PyTorch
  output_pth = model_pth(inp)

  # TVM
  graph_model = torch.fx.symbolic_trace(model_pth)
  with torch.no_grad():
    mod = from_fx(graph_model, [(inp.shape, "float32")])

  target = tvm.target.Target("llvm", host="llvm")
  mod = relax.transform.DecomposeOpsForInference()(mod)
  mod = relax.transform.LegalizeOps()(mod)
  ex = relax.build(mod, target)
  vm = relax.VirtualMachine(ex, tvm.cpu())
  output_tvm = torch.tensor(vm["main"](tvm.nd.array(inp.detach().numpy())).numpy())

  torch.testing.assert_close(output_pth, output_tvm, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
  main()
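The last line of the repro compares the PyTorch and TVM outputs with rtol=1e-5 and atol=1e-5. torch.testing.assert_close treats two values as close when |actual - expected| <= atol + rtol * |expected|. The sketch below reproduces only that elementwise criterion on flat lists; within_tolerance is a hypothetical helper, and the real function additionally checks shape, dtype and device and handles NaN/inf.

```python
from typing import Sequence


def within_tolerance(actual: Sequence[float], expected: Sequence[float],
                     rtol: float = 1e-5, atol: float = 1e-5) -> bool:
    """Elementwise |actual - expected| <= atol + rtol * |expected| (hypothetical helper)."""
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


print(within_tolerance([1.0, 2.0], [1.0, 2.00001]))  # True
print(within_tolerance([1.0, 2.0], [1.0, 2.001]))    # False
```

Because the bound grows with |expected|, the allowed error for large logits is slightly looser than atol alone.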

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage

cc @junrushao

Labels: needs-triage, type: bug