From b0c1985ed749f379ffa76ca3a55e248573bd129f Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Fri, 15 Sep 2023 10:14:26 +0800 Subject: [PATCH 1/3] fix flip correct the calculate logic of flip in pytorch converter --- python/tvm/relay/frontend/pytorch.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9ddd04b5b4ee..f716a03e7a36 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2977,7 +2977,12 @@ def nll_loss(self, inputs, input_types): def flip(self, inputs, input_types): data = inputs[0] axis = inputs[1] - return _op.transform.reverse(data, axis=axis[0]) + for i, ax in enumerate(axis): + if i == 0: + out = _op.reverse(data, ax) + else: + out = _op.reverse(out, ax) + return out def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh): """ From ead36a5416d2b1d68830e24a2653e9af5d0a2d8e Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Fri, 15 Sep 2023 10:17:05 +0800 Subject: [PATCH 2/3] Update test_forward.py --- tests/python/frontend/pytorch/test_forward.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9ee03512e7ae..6bbb9ef5cc4a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4899,13 +4899,14 @@ def __init__(self, axis=0): self.axis = axis def forward(self, x): - return x.flip([self.axis]) + return x.flip(self.axis) input_t = torch.randn(2, 3, 4) - verify_model(Flip(axis=0), input_data=input_t) - verify_model(Flip(axis=1), input_data=input_t) - verify_model(Flip(axis=2), input_data=input_t) - verify_model(Flip(axis=-1), input_data=input_t) + verify_model(Flip(axis=[0]), input_data=input_t) + verify_model(Flip(axis=[1]), input_data=input_t) + verify_model(Flip(axis=[2]), input_data=input_t) + verify_model(Flip(axis=[-1]), input_data=input_t) + verify_model(Flip(axis=[0, 1]), input_data=input_t) def test_annotate_span(): From 0a57fc01db20c884ad9950a73a0fc02ec67b6531 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Fri, 15 Sep 2023 21:10:32 +0800 Subject: [PATCH 3/3] Update python/tvm/relay/frontend/pytorch.py Co-authored-by: Egor Churaev --- python/tvm/relay/frontend/pytorch.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index f716a03e7a36..89dcad03e6ad 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2977,11 +2977,9 @@ def nll_loss(self, inputs, input_types): def flip(self, inputs, input_types): data = inputs[0] axis = inputs[1] - for i, ax in enumerate(axis): - if i == 0: - out = _op.reverse(data, ax) - else: - out = _op.reverse(out, ax) + out = data + for ax in axis: + out = _op.reverse(out, ax) return out def bidir_rnn_cell(self, input_seqs, weights_dicts, act=_op.tanh):