Labels
needs-triage, type: bug
Description
mobilenet_v3 cannot be imported because torch.nn.Hardswish and torch.nn.Hardsigmoid are not supported by the Relax PyTorch (FX) frontend. I'll try to fix it; see the decomposition sketch after the TODO list below.
TODOs
- [Relax] [PyTorch] Add support for torch.nn.Hardswish #17084
- [Relax] [PyTorch] Add support for torch.nn.Hardsigmoid #17085
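For context, both activations decompose into basic elementwise ops (add, clamp, multiply, divide), so supporting them should mostly be a matter of adding converters. A quick PyTorch check of the formulas (this snippet is illustrative only and not part of the linked PRs):

import torch

x = torch.randn(1024)

# Hardsigmoid(x) = clip(x + 3, 0, 6) / 6
hardsigmoid_ref = torch.clamp(x + 3.0, 0.0, 6.0) / 6.0
torch.testing.assert_close(torch.nn.Hardsigmoid()(x), hardsigmoid_ref)

# Hardswish(x) = x * clip(x + 3, 0, 6) / 6
hardswish_ref = x * torch.clamp(x + 3.0, 0.0, 6.0) / 6.0
torch.testing.assert_close(torch.nn.Hardswish()(x), hardswish_ref)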
Expected behavior
mobilenet_v3_small and mobilenet_v3_large can be imported with from_fx.
Actual behavior
Got the error message below when executing the repro:
$ python compile_mobilenet_v3.py
Traceback (most recent call last):
File "/home/ubuntu/data/sandbox/tvm_/relax_/mobilenet_v3/compile_mobilenet_v3.py", line 34, in <module>
main()
File "/home/ubuntu/data/sandbox/tvm_/relax_/mobilenet_v3/compile_mobilenet_v3.py", line 21, in main
mod = from_fx(graph_model, [(inp.shape, "float32")])
File "/home/ubuntu/data/sandbox/.dep/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1698, in from_fx
return TorchFXImporter().from_fx(
File "/home/ubuntu/data/sandbox/.dep/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1570, in from_fx
type(module) in self.convert_map
AssertionError: Unsupported module type <class 'torch.nn.modules.activation.Hardswish'>
[20:07:07] /home/ubuntu/data/sandbox/.dep/tvm/src/relax/ir/block_builder.cc:66: Warning: BlockBuilder destroyed with remaining blocks!
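As the traceback shows, the failure comes from the "type(module) in self.convert_map" assertion in TorchFXImporter. A quick way to check which module types are missing, based only on the internals visible in the traceback (convert_map is not a public API and may change), is:

import torch
from tvm.relax.frontend.torch.fx_translator import TorchFXImporter

importer = TorchFXImporter()
# from_fx asserts `type(module) in self.convert_map` for each traced nn.Module,
# so membership here tells us whether a converter is registered.
print(torch.nn.Hardswish in importer.convert_map)    # expected: False on 0e622e1
print(torch.nn.Hardsigmoid in importer.convert_map)  # expected: False on 0e622e1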
Environment
OS: Ubuntu 22.04 LTS on WSL2
TVM: 0e622e1
PyTorch: 2.3.0
Torchvision: 0.18.0
Steps to reproduce
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx
import torch
import torchvision


def main():
    model_name = "mobilenet_v3_small"  # mobilenet_v3_small or mobilenet_v3_large
    inp = torch.rand(8, 3, 224, 224)
    weights = torchvision.models.get_model_weights(model_name).DEFAULT
    model_pth = torchvision.models.get_model(model_name, weights=weights).eval()

    # PyTorch reference output
    output_pth = model_pth(inp)

    # TVM: trace the model with torch.fx and import it into Relax
    graph_model = torch.fx.symbolic_trace(model_pth)
    with torch.no_grad():
        mod = from_fx(graph_model, [(inp.shape, "float32")])

    # Lower, build, and run on CPU
    target = tvm.target.Target("llvm", host="llvm")
    mod = relax.transform.DecomposeOpsForInference()(mod)
    mod = relax.transform.LegalizeOps()(mod)
    ex = relax.build(mod, target)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    output_tvm = torch.tensor(vm["main"](tvm.nd.array(inp.detach().numpy())).numpy())

    # Outputs should match within tolerance
    torch.testing.assert_close(output_pth, output_tvm, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
    main()

Triage
Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).
- needs-triage
cc @junrushao