diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 08bf5d517c8b..35cb86e41dfe 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -399,7 +399,8 @@ def squeeze(self, inputs, input_types): axis = None else: # TODO (t-vi): why is the cast to int needed? similarly elsewhere - axis = [int(inputs[1])] + inputs = [inputs[1]] if not isinstance(inputs[1], list) else inputs[1] + axis = [int(v) for v in inputs] return _op.transform.squeeze(data, axis) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index b602c14df3af..65134ca15a34 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -578,9 +578,15 @@ class Squeeze2(Module): def forward(self, *args): return args[0].squeeze(1) + class Squeeze3(Module): + def forward(self, *args): + return args[0].squeeze((1, 3)) + input_data = torch.rand(input_shape).float() verify_model(Squeeze1().float().eval(), input_data=input_data) verify_model(Squeeze2().float().eval(), input_data=input_data) + if package_version.parse(torch.__version__) >= package_version.parse("2.0.0"): + verify_model(Squeeze3().float().eval(), input_data=input_data) @tvm.testing.uses_gpu