diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9ddd04b5b4ee..89dcad03e6ad 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2977,7 +2977,10 @@ def nll_loss(self, inputs, input_types): def flip(self, inputs, input_types): data = inputs[0] axis = inputs[1] - return _op.transform.reverse(data, axis=axis[0]) + out = data + for ax in axis: + out = _op.reverse(out, ax) + return out def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh): """ diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9ee03512e7ae..6bbb9ef5cc4a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4899,13 +4899,14 @@ def __init__(self, axis=0): self.axis = axis def forward(self, x): - return x.flip([self.axis]) + return x.flip(self.axis) input_t = torch.randn(2, 3, 4) - verify_model(Flip(axis=0), input_data=input_t) - verify_model(Flip(axis=1), input_data=input_t) - verify_model(Flip(axis=2), input_data=input_t) - verify_model(Flip(axis=-1), input_data=input_t) + verify_model(Flip(axis=[0]), input_data=input_t) + verify_model(Flip(axis=[1]), input_data=input_t) + verify_model(Flip(axis=[2]), input_data=input_t) + verify_model(Flip(axis=[-1]), input_data=input_t) + verify_model(Flip(axis=[0, 1]), input_data=input_t) def test_annotate_span():