Skip to content

Conversation

@masahi
Copy link
Member

@masahi masahi commented Oct 26, 2021

Closes #9368

This PR adds missing converters necessary for importing the recent state of the art model Swin-Transoformer.

In particular, aten::roll is an interesting op to implement and generally useful op to have. It can be implemented via gather but the encoding is not obvious. Here are the references:
https://pytorch.org/docs/stable/generated/torch.roll.html
https://numpy.org/doc/stable/reference/generated/numpy.roll.html

please review @comaniac @jcf94 @junrushao1994

@Kyrie-Zhao aten::rand is not needed, it is used in dropout but if you trace the model with eval mode, dropout is gone.

Now the following script works with the error 1.4081733e-07.

import numpy as np
import tvm
import torch
from tvm import relay
from swin_transformer import SwinTransformer

net = SwinTransformer().eval()

img = torch.randn(1, 3, 224, 224)

scripted_model = torch.jit.trace(net, img).eval()
input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

with torch.no_grad():
    pt_result = net(img).numpy()

target = "llvm"

with tvm.transform.PassContext(opt_level=3):
    json, lib, params = relay.build(mod, target=target, params=params)

ctx = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.create(json, lib, ctx)
runtime.set_input(**params)
runtime.set_input("input0", img.numpy())
runtime.run()

tvm_result = runtime.get_output(0).asnumpy()

print(np.mean(np.abs(tvm_result - pt_result)))

Copy link
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks awesome! Thank you @masahi!

@junrushao junrushao merged commit 0df4edc into apache:main Oct 26, 2021
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 7, 2022
* add test

* first impl

* basic example working

* all test cases working

* support adaptive avg and max pool

* cleanup

* axes transpose logic fixed for roll

* pylint

* fixed roll dim indexing
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 13, 2022
* add test

* first impl

* basic example working

* all test cases working

* support adaptive avg and max pool

* cleanup

* axes transpose logic fixed for roll

* pylint

* fixed roll dim indexing
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] A list of missing op conversion (Swin Transformer)

2 participants