From 8c9c708fbe278a3de49d81847c4a3c9d64f693e6 Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 1 Mar 2021 13:04:43 +0900 Subject: [PATCH 1/2] Fix converting torch slice op with dynamic slice length --- python/tvm/relay/frontend/pytorch.py | 8 +++++++- tests/python/frontend/pytorch/test_forward.py | 9 +++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 31c78cfdea84..159c803dd69c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -400,7 +400,13 @@ def slice(self, inputs, input_types): ) # A fast path when slicing is nop. - if target_begin == 0 and target_end >= index_size_limit and stride == 1: + if ( + is_begin_const + and is_end_const + and target_begin == 0 + and target_end >= index_size_limit + and stride == 1 + ): return data # Process begin diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 826edd051544..9f035ade7a21 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1497,6 +1497,10 @@ class SliceWithStride2(torch.nn.Module): def forward(self, x): return x[0::2, 0::2] + x[1::2, 1::2] + class DynamicLengthSlice(torch.nn.Module): + def forward(self, values, length): + return values[0:length] + input_data = torch.rand(input_shape).float() verify_model(Slice1(), input_data=input_data) verify_model(Slice2(), input_data=input_data) @@ -1504,6 +1508,11 @@ def forward(self, x): verify_model(SliceWithStride(), input_data=torch.randn(1, 4)) verify_model(SliceWithStride2(), input_data=torch.randn(4, 4)) + inp = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + slice_len = torch.tensor(2) + targets = ["llvm", "cuda"] + verify_trace_model(DynamicLengthSlice(), [inp, slice_len], targets) + @tvm.testing.uses_gpu def test_forward_narrow(): From 1f3f00daa1047057b65a9680fd9d577e1d0988d3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 2 Mar 2021 17:40:25 +0900 Subject: [PATCH 2/2] use isinstance --- python/tvm/relay/frontend/pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 159c803dd69c..3c61749fc203 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -401,8 +401,8 @@ def slice(self, inputs, input_types): # A fast path when slicing is nop. if ( - is_begin_const - and is_end_const + isinstance(target_begin, int) + and isinstance(target_end, int) and target_begin == 0 and target_end >= index_size_limit and stride == 1