From 05b1bdf9b381f715a3358ca3c1497d87c504da99 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 7 Nov 2022 19:35:18 +0900 Subject: [PATCH 1/2] [Torch] Fix advanced indexing with boolean mask --- python/tvm/relay/frontend/pytorch.py | 11 +++++++++-- tests/python/frontend/pytorch/test_forward.py | 9 +++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 1b86b120dfcc..870fc1dc03c5 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2323,8 +2323,15 @@ def one_hot(self, inputs, input_types): def index(self, inputs, input_types): data = inputs[0] - indices = inputs[1] - return _op.adv_index([data] + indices) + indices_list = [] + + for indices in inputs[1]: + if self.infer_type(indices).dtype == "bool": + indices_list.append(_op.squeeze(_op.transform.argwhere(indices), axis=[1])) + else: + indices_list.append(indices) + + return _op.adv_index([data] + indices_list) def meshgrid(self, inputs, input_types): data = inputs[0] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 8045635127bb..a6396968da1c 100755 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4034,6 +4034,15 @@ def forward(self, x): input_data = torch.rand(input_shape).float() verify_model(Index1().eval(), input_data=input_data) + def test_fn_bool_mask(): + return lambda data, mask: data[0, mask] + + data = torch.tensor([[1, 2, 3], [4, 5, 6]]) + + mask = torch.tensor([True, True, False]) + + verify_trace_model(test_fn_bool_mask(), [data, mask], ["llvm", "cuda"]) + def test_logsumexp(): """test_logsumexp""" From e4b416760eb46e788b9644fc549d99162f259a31 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 7 Nov 2022 19:42:29 +0900 Subject: [PATCH 2/2] add comment --- python/tvm/relay/frontend/pytorch.py | 4 ++++ tests/python/frontend/pytorch/test_forward.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 870fc1dc03c5..30f14b490b1b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2327,6 +2327,10 @@ def index(self, inputs, input_types): for indices in inputs[1]: if self.infer_type(indices).dtype == "bool": + # adv_index does not support a mask as the index tensor (it will treat 0/1 as + # an index rather than a flag). + # So we use argwhere to turn the mask into indices, which will also take care + # of the dynamism in the indexing by mask. indices_list.append(_op.squeeze(_op.transform.argwhere(indices), axis=[1])) else: indices_list.append(indices) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index a6396968da1c..36bb5bede475 100755 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4038,7 +4038,6 @@ def test_fn_bool_mask(): return lambda data, mask: data[0, mask] data = torch.tensor([[1, 2, 3], [4, 5, 6]]) - mask = torch.tensor([True, True, False]) verify_trace_model(test_fn_bool_mask(), [data, mask], ["llvm", "cuda"])