From ea909162c1a02d7749cff3acb9a9a81695173e65 Mon Sep 17 00:00:00 2001 From: Luong Ducnhat Date: Mon, 6 Nov 2023 14:23:43 +0900 Subject: [PATCH] support the pytorch's maxvit model by adding the aten::swapaxes operator support. Co-authored-by: Masahiro Hiramori --- python/tvm/relay/frontend/pytorch.py | 1 + tests/python/frontend/pytorch/test_forward.py | 24 +++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 81392a08ecd1..402ab592027c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -4108,6 +4108,7 @@ def create_convert_map(self): "aten::multinomial": self.multinomial, "aten::_weight_norm": self.weight_norm, "aten::copy_": self.inplace_copy, + "aten::swapaxes": self.transpose, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index abdbda8e4005..b9c1b6ce9cd1 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -5381,6 +5381,30 @@ def forward(self, x): verify_model(PartialDimensionInplaceCopy(), [inputs]) +@tvm.testing.uses_gpu +def test_swapaxes(): + """test_swapaxes""" + torch.set_grad_enabled(False) + input_shape = [2, 3, 10, 5] + + class Swapaxes1(Module): + def forward(self, *args): + return args[0].swapaxes(2, 3) + + class Swapaxes2(Module): + def forward(self, *args): + return args[0].swapaxes(-2, -1) + + class Swapaxes3(Module): + def forward(self, *args): + return args[0].swapaxes(1, 1) + + input_data = torch.rand(input_shape).float() + verify_model(Swapaxes1().float().eval(), input_data=input_data) + verify_model(Swapaxes2().float().eval(), input_data=input_data) + verify_model(Swapaxes3().float().eval(), input_data=input_data) + + class TestSetSpan: """test structural equal between translated / hand-crafted relay IR with span tagged."""