From ee32d23008dadce0f8d4fde857222f043d7f897e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 25 Jan 2021 11:29:45 +0900 Subject: [PATCH 1/9] add conversion for detr --- python/tvm/relay/frontend/pytorch.py | 31 ++++++++++++++++--- tests/python/frontend/pytorch/test_forward.py | 28 +++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 991e3a8a0032..fcf0c83d3143 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -399,10 +399,7 @@ def slice(self, inputs, input_types): begin = [0] * ndim dim = int(inputs[1]) stride = int(inputs[4]) - if isinstance(inputs[2], _expr.Call): - begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) - else: - begin[dim] = int(inputs[2]) + begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int))) # Process begin if not isinstance(begin[dim], int): @@ -551,7 +548,13 @@ def reciprocal(self, inputs, input_types): def repeat(self, inputs, input_types): data = inputs[0] - reps = inputs[1] + reps = [] + for r in inputs[1]: + if isinstance(r, int): + reps.append(r) + else: + reps.append(int(_infer_value(r, {}).asnumpy())) + return _op.transform.tile(data, reps=reps) def repeat_interleave(self, inputs, input_types): @@ -2070,6 +2073,22 @@ def scatter_add(self, inputs, input_types): src = inputs[3] return _op.scatter_add(data, index, src, axis=axis) + def cumsum(self, inputs, input_types): + data = inputs[0] + dim = inputs[1] + dtype = inputs[2] + + if inputs[2] is not None: + dtype = _convert_dtype_value(inputs[2]) + + return _op.cumsum(data, axis=dim, dtype=dtype) + + def masked_fill(self, inputs, input_types): + mask = inputs[1] + value = _op.cast(_wrap_const(inputs[2]), input_types[0]) + + return _op.where(mask, value, inputs[0]) + def is_floating_point(self, inputs, input_types): assert len(inputs) == 1 @@ -2278,6 +2297,8 @@ def create_convert_map(self): "aten::__not__": self.logical_not, "aten::hardswish_": self.hard_swish, "aten::hardswish": self.hard_swish, + "aten::cumsum": self.cumsum, + "aten::masked_fill": self.masked_fill, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 7cdd450448ca..3627df3e029b 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3452,6 +3452,32 @@ def test_hard_swish(): verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input) +def test_cumsum(): + def test_fn(dim, dtype=None): + return lambda x: torch.cumsum(x, dim=dim, dtype=dtype) + + inp = torch.randint(0, 100, (10000,), dtype=torch.int32) + verify_model(test_fn(0), [inp]) + verify_model(test_fn(0), [inp.to(torch.int64)]) + verify_model(test_fn(0, dtype=torch.int64), [inp.to(torch.int64)]) + + inp = torch.randn((100, 100), dtype=torch.float32) + verify_model(test_fn(dim=0, dtype=torch.float64), [inp]) + verify_model(test_fn(dim=1), [inp]) + + inp = torch.randn((100, 100), dtype=torch.float32) > 0.5 + verify_model(test_fn(dim=0, dtype=torch.int32), [inp]) + + +def test_masked_fill(): + def test_fn(x, mask): + return torch.masked_fill(x, mask, 0.0) + + inp = torch.randn(100, 100) + verify_model(test_fn, [inp, inp > 0.5]) + verify_model(test_fn, [inp.to(torch.float64), inp > 0.5]) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -3580,6 +3606,8 @@ def test_hard_swish(): test_forward_scatter() test_numel() test_bincount() + test_cumsum() + test_masked_fill() # Model tests test_resnet18() From 4e35b33fd3f19d067750a1c2f533291c611bc5f6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Jan 2021 08:51:48 +0900 Subject: [PATCH 2/9] remove explicit broadcast_to before batched matmul --- python/tvm/relay/frontend/pytorch.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index fcf0c83d3143..2132ab00c0ae 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1523,12 +1523,6 @@ def matmul(self, inputs, input_types): # Convert a and b into 3 dimensional tensors. a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) - # Broadcast b to match batch size of a - new_b_shape = list(self.infer_shape_with_prelude(b)) - new_a_shape = self.infer_shape_with_prelude(a) - if new_a_shape[0] > new_b_shape[0]: - new_b_shape[0] = new_a_shape[0] - b = _op.broadcast_to(b, new_b_shape) # Transpose matrix dimensions of b. b = _op.transpose(b, [0, 2, 1]) # Perform a batch matmul. From 2ef69535d3f94e57e2b190ffd2addd2df8888730 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Jan 2021 09:38:40 +0900 Subject: [PATCH 3/9] use take with wrap mode --- python/tvm/relay/frontend/pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 2132ab00c0ae..007a1f7d45ca 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -515,13 +515,13 @@ def select(self, inputs, input_types): data = inputs[0] dim = int(inputs[1]) index = _wrap_const(inputs[2]) - return _op.transform.take(data, index, axis=dim) + return _op.transform.take(data, index, axis=dim, mode="wrap") def take(self, inputs, input_types): data = inputs[0] indices = _op.cast(inputs[1], "int32") - return _op.transform.take(data, indices=indices) + return _op.transform.take(data, indices=indices, mode="wrap") def topk(self, inputs, input_types): data = inputs[0] From 92d2a8a4f8ac688e09620017b6f9f496c3e95ef2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Jan 2021 18:30:18 +0900 Subject: [PATCH 4/9] add test for transformer and negative indices --- tests/python/frontend/pytorch/test_forward.py | 350 +++++++++--------- 1 file changed, 182 insertions(+), 168 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3627df3e029b..cc84bdaac342 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1147,7 +1147,7 @@ def forward(self, *args): @tvm.testing.uses_gpu def test_forward_select(): torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] + input_shape = [5, 3, 10, 10] class Select1(Module): def forward(self, *args): @@ -1167,6 +1167,9 @@ def forward(self, index): input_data = torch.rand(input_shape).float() verify_model(Select1().float().eval(), input_data=input_data) + # test negative indexing + verify_model(lambda x: x[-1], input_data=input_data) + x = torch.randn(3, 4) indices = torch.tensor([0, 2]) verify_model(IndexedSelect(x, 0).eval(), input_data=indices) @@ -2653,6 +2656,8 @@ def forward(self, *args): verify_model(Take1().float().eval(), input_data=input_data) indices = torch.tensor([[0, 0], [1, 0]]) verify_model(Take2().float().eval(), input_data=[input_data, indices]) + indices = torch.tensor([0, -1]) + verify_model(Take2().float().eval(), input_data=[input_data, indices]) @tvm.testing.uses_gpu @@ -3478,172 +3483,181 @@ def test_fn(x, mask): verify_model(test_fn, [inp.to(torch.float64), inp > 0.5]) +def test_transformer(): + model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6) + model = model.eval() + src = torch.rand((10, 32, 256)) + tgt = torch.rand((20, 32, 256)) + verify_model(model.eval(), input_data=[src, tgt]) + + if __name__ == "__main__": - # some structural tests - test_forward_traced_function() - test_forward_dtypes() - test_weight_names() - test_duplicate_weight_use() - - # Single operator tests - test_forward_pixel_shuffle() - test_forward_add() - test_forward_subtract() - test_forward_multiply() - test_forward_matmul() - test_forward_rsub() - test_forward_onehot() - test_forward_embedding() - test_forward_reshape() - test_forward_reciprocal() - test_forward_repeat() - test_forward_repeat_interleave() - test_forward_squeeze() - test_forward_unsqueeze() - test_forward_concatenate() - test_forward_reduce_sum() - test_forward_reduce_prod() - test_forward_argmin() - test_forward_argmax() - test_forward_norm() - test_forward_frobenius_norm() - test_forward_std() - test_forward_variance() - test_forward_relu() - test_forward_prelu() - test_forward_leakyrelu() - test_forward_elu() - test_forward_celu() - test_forward_gelu() - test_forward_selu() - test_forward_log_sigmoid() - test_forward_adaptiveavgpool() - test_forward_maxpool2d() - test_forward_maxpool1d() - test_forward_maxpool3d() - test_forward_hardtanh() - test_forward_conv() - test_forward_conv_transpose() - test_forward_threshold() - test_forward_contiguous() - test_forward_batchnorm() - test_forward_instancenorm() - test_forward_layernorm() - test_forward_groupnorm() - test_forward_transpose() - test_forward_size() - test_forward_view() + # # some structural tests + # test_forward_traced_function() + # test_forward_dtypes() + # test_weight_names() + # test_duplicate_weight_use() + + # # Single operator tests + # test_forward_pixel_shuffle() + # test_forward_add() + # test_forward_subtract() + # test_forward_multiply() + # test_forward_matmul() + # test_forward_rsub() + # test_forward_onehot() + # test_forward_embedding() + # test_forward_reshape() + # test_forward_reciprocal() + # test_forward_repeat() + # test_forward_repeat_interleave() + # test_forward_squeeze() + # test_forward_unsqueeze() + # test_forward_concatenate() + # test_forward_reduce_sum() + # test_forward_reduce_prod() + # test_forward_argmin() + # test_forward_argmax() + # test_forward_norm() + # test_forward_frobenius_norm() + # test_forward_std() + # test_forward_variance() + # test_forward_relu() + # test_forward_prelu() + # test_forward_leakyrelu() + # test_forward_elu() + # test_forward_celu() + # test_forward_gelu() + # test_forward_selu() + # test_forward_log_sigmoid() + # test_forward_adaptiveavgpool() + # test_forward_maxpool2d() + # test_forward_maxpool1d() + # test_forward_maxpool3d() + # test_forward_hardtanh() + # test_forward_conv() + # test_forward_conv_transpose() + # test_forward_threshold() + # test_forward_contiguous() + # test_forward_batchnorm() + # test_forward_instancenorm() + # test_forward_layernorm() + # test_forward_groupnorm() + # test_forward_transpose() + # test_forward_size() + # test_forward_view() test_forward_select() - test_forward_take() - test_forward_topk() - test_forward_where() - test_forward_addcdiv() - test_forward_addcmul() - test_forward_true_divide() - test_forward_is_floating_point() - test_forward_clone() - test_forward_softplus() - test_forward_softsign() - test_forward_logsoftmax() - test_forward_sigmoid() - test_forward_dense() - test_forward_avgpool() - test_forward_avgpool3d() - test_forward_dropout() - test_forward_slice() - test_forward_mean() - test_forward_expand() - test_forward_pow() - test_forward_unary() - test_forward_clamp() - test_forward_clamp_() - test_forward_logical_not() - test_forward_bitwise_not() - test_forward_bitwise_xor() - test_forward_logical_xor() - test_forward_isfinite() - test_forward_isnan() - test_forward_isinf() - test_forward_ones() - test_forward_ones_like() - test_forward_zeros() - test_forward_zeros_like() - test_forward_full() - test_forward_full_like() - test_forward_linspace() - test_forward_arange() - test_forward_mesh_grid() - test_forward_chunk() - test_forward_split() - test_forward_gather() - test_upsample() - test_forward_upsample3d() - test_forward_nms() - test_forward_roi_align() - test_to() - test_flatten() - test_type_as() - test_forward_functional_pad() - test_forward_zero_pad2d() - test_forward_constant_pad1d() - test_forward_constant_pad2d() - test_forward_constant_pad3d() - test_forward_reflection_pad1d() - test_forward_reflection_pad2d() - test_forward_replication_pad1d() - test_forward_replication_pad2d() - test_forward_replication_pad3d() - test_adaptive_pool3d() - test_conv3d() - test_conv3d_transpose() - test_forward_index() - test_min_max() - test_logsumexp() - test_stack() - test_stack_dynamic() - test_forward_unbind() - test_forward_nonzero() - test_forward_scatter() - test_numel() - test_bincount() - test_cumsum() - test_masked_fill() - - # Model tests - test_resnet18() - test_squeezenet1_0() - test_squeezenet1_1() - test_densenet121() - # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug - # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 - # test_inception_v3() - test_googlenet() - test_mnasnet0_5() - test_mobilenet_v2() - - test_custom_conversion_map() - - test_segmentaton_models() - test_3d_models() - - # Quantization test - from qnn_test import test_quantized_imagenet, test_quantized_modules - - test_quantized_modules() - test_quantized_imagenet() - - # Test simple conditionals and loop - test_control_flow() - test_simple_rnn() - - # More complex recurrent models - from test_lstm import test_custom_lstm - - test_custom_lstm() - - # Test bert model - test_forward_pretrained_bert_base_uncased() - - # Test convert torch script(jit) with specific inputs' types - test_convert_torch_script_with_input_types() - test_hard_swish() + # test_forward_take() + # test_forward_topk() + # test_forward_where() + # test_forward_addcdiv() + # test_forward_addcmul() + # test_forward_true_divide() + # test_forward_is_floating_point() + # test_forward_clone() + # test_forward_softplus() + # test_forward_softsign() + # test_forward_logsoftmax() + # test_forward_sigmoid() + # test_forward_dense() + # test_forward_avgpool() + # test_forward_avgpool3d() + # test_forward_dropout() + # test_forward_slice() + # test_forward_mean() + # test_forward_expand() + # test_forward_pow() + # test_forward_unary() + # test_forward_clamp() + # test_forward_clamp_() + # test_forward_logical_not() + # test_forward_bitwise_not() + # test_forward_bitwise_xor() + # test_forward_logical_xor() + # test_forward_isfinite() + # test_forward_isnan() + # test_forward_isinf() + # test_forward_ones() + # test_forward_ones_like() + # test_forward_zeros() + # test_forward_zeros_like() + # test_forward_full() + # test_forward_full_like() + # test_forward_linspace() + # test_forward_arange() + # test_forward_mesh_grid() + # test_forward_chunk() + # test_forward_split() + # test_forward_gather() + # test_upsample() + # test_forward_upsample3d() + # test_forward_nms() + # test_forward_roi_align() + # test_to() + # test_flatten() + # test_type_as() + # test_forward_functional_pad() + # test_forward_zero_pad2d() + # test_forward_constant_pad1d() + # test_forward_constant_pad2d() + # test_forward_constant_pad3d() + # test_forward_reflection_pad1d() + # test_forward_reflection_pad2d() + # test_forward_replication_pad1d() + # test_forward_replication_pad2d() + # test_forward_replication_pad3d() + # test_adaptive_pool3d() + # test_conv3d() + # test_conv3d_transpose() + # test_forward_index() + # test_min_max() + # test_logsumexp() + # test_stack() + # test_stack_dynamic() + # test_forward_unbind() + # test_forward_nonzero() + # test_forward_scatter() + # test_numel() + # test_bincount() + # test_cumsum() + # test_masked_fill() + # test_transformer() + + # # Model tests + # test_resnet18() + # test_squeezenet1_0() + # test_squeezenet1_1() + # test_densenet121() + # # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug + # # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 + # # test_inception_v3() + # test_googlenet() + # test_mnasnet0_5() + # test_mobilenet_v2() + + # test_custom_conversion_map() + + # test_segmentaton_models() + # test_3d_models() + + # # Quantization test + # from qnn_test import test_quantized_imagenet, test_quantized_modules + + # test_quantized_modules() + # test_quantized_imagenet() + + # # Test simple conditionals and loop + # test_control_flow() + # test_simple_rnn() + + # # More complex recurrent models + # from test_lstm import test_custom_lstm + + # test_custom_lstm() + + # # Test bert model + # test_forward_pretrained_bert_base_uncased() + + # # Test convert torch script(jit) with specific inputs' types + # test_convert_torch_script_with_input_types() + # test_hard_swish() From fd91f7c178ac4eec2ae8fd2676b8b05742c60b55 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Jan 2021 18:55:17 +0900 Subject: [PATCH 5/9] add sort and argsort --- python/tvm/relay/frontend/pytorch.py | 21 +++++++++++- tests/python/frontend/pytorch/test_forward.py | 34 ++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 007a1f7d45ca..9c288ccac486 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2080,9 +2080,26 @@ def cumsum(self, inputs, input_types): def masked_fill(self, inputs, input_types): mask = inputs[1] value = _op.cast(_wrap_const(inputs[2]), input_types[0]) - return _op.where(mask, value, inputs[0]) + def sort(self, inputs, input_types): + data = inputs[0] + dim = inputs[1] + is_descending = inputs[2] + # pytorch sort returns both sorted indices and values + indices = _op.argsort(data, dim, not is_descending) + shape = self.infer_shape(data) + if len(shape) == 1: + return _op.take(data, indices, dim), indices + # TOOD(masahi): Is there a better way? + return _op.sort(data, dim, not is_descending), indices + + def argsort(self, inputs, input_types): + data = inputs[0] + dim = inputs[1] + is_descending = inputs[2] + return _op.argsort(data, dim, not is_descending) + def is_floating_point(self, inputs, input_types): assert len(inputs) == 1 @@ -2293,6 +2310,8 @@ def create_convert_map(self): "aten::hardswish": self.hard_swish, "aten::cumsum": self.cumsum, "aten::masked_fill": self.masked_fill, + "aten::argsort": self.argsort, + "aten::sort": self.sort, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index cc84bdaac342..63ffed49778b 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3491,6 +3491,36 @@ def test_transformer(): verify_model(model.eval(), input_data=[src, tgt]) +def test_argsort(): + def test_fn(dim, descending): + return lambda x: torch.argsort(x, dim=dim, descending=descending) + + inp = torch.randn(100) + verify_model(test_fn(0, True), [inp]) + verify_model(test_fn(0, False), [inp]) + + inp = torch.randn(100, 100) + verify_model(test_fn(0, True), [inp]) + verify_model(test_fn(0, False), [inp]) + verify_model(test_fn(1, True), [inp]) + verify_model(test_fn(1, False), [inp]) + + +def test_sort(): + def test_fn(dim, descending): + return lambda x: torch.sort(x, dim=dim, descending=descending) + + inp = torch.randn(100) + verify_model(test_fn(0, True), [inp]) + verify_model(test_fn(0, False), [inp]) + + inp = torch.randn(100, 100) + verify_model(test_fn(0, True), [inp]) + verify_model(test_fn(0, False), [inp]) + verify_model(test_fn(1, True), [inp]) + verify_model(test_fn(1, False), [inp]) + + if __name__ == "__main__": # # some structural tests # test_forward_traced_function() @@ -3546,7 +3576,7 @@ def test_transformer(): # test_forward_transpose() # test_forward_size() # test_forward_view() - test_forward_select() + # test_forward_select() # test_forward_take() # test_forward_topk() # test_forward_where() @@ -3622,6 +3652,8 @@ def test_transformer(): # test_cumsum() # test_masked_fill() # test_transformer() + test_sort() + test_argsort() # # Model tests # test_resnet18() From f13417adc74c2d42c9b462d105e46ed1defb1336 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Jan 2021 19:03:14 +0900 Subject: [PATCH 6/9] add logical_and --- python/tvm/relay/frontend/pytorch.py | 1 + tests/python/frontend/pytorch/test_forward.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9c288ccac486..14a6e28a183d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2293,6 +2293,7 @@ def create_convert_map(self): "torchvision::roi_align": self.roi_align, "aten::unbind": self.unbind, "aten::__and__": self.logical_and, + "aten::logical_and": self.logical_and, "aten::_shape_as_tensor": self.shape_as_tensor, "aten::nonzero": self.nonzero, "aten::nonzero_numpy": self.nonzero_numpy, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 63ffed49778b..a262514aa7e5 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3521,6 +3521,19 @@ def test_fn(dim, descending): verify_model(test_fn(1, False), [inp]) +def test_logical_and(): + def test_fn(x, y): + return torch.logical_and(x, y) + + a = torch.tensor([0, 1, 10, 0], dtype=torch.int8) + b = torch.tensor([4, 0, 1, 0], dtype=torch.int8) + verify_model(test_fn, [a, b]) + + a = torch.tensor([True, False, True]) + b = torch.tensor([True, False, False]) + verify_model(test_fn, [a, b]) + + if __name__ == "__main__": # # some structural tests # test_forward_traced_function() @@ -3652,8 +3665,9 @@ def test_fn(dim, descending): # test_cumsum() # test_masked_fill() # test_transformer() - test_sort() - test_argsort() + # test_sort() + # test_argsort() + test_logical_and() # # Model tests # test_resnet18() From 949eece15ef3e3d99c3759ce7ab80ff54fd2ac78 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Jan 2021 19:41:41 +0900 Subject: [PATCH 7/9] support masked_select --- python/tvm/relay/frontend/pytorch.py | 6 + tests/python/frontend/pytorch/test_forward.py | 333 +++++++++--------- 2 files changed, 178 insertions(+), 161 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 14a6e28a183d..7c033459217a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2082,6 +2082,11 @@ def masked_fill(self, inputs, input_types): value = _op.cast(_wrap_const(inputs[2]), input_types[0]) return _op.where(mask, value, inputs[0]) + def masked_select(self, inputs, input_types): + mask = inputs[1] + indices = self.nonzero([mask], input_types, is_numpy_style=True) + return _op.adv_index([inputs[0]] + [indices[i] for i in range(indices.size)]) + def sort(self, inputs, input_types): data = inputs[0] dim = inputs[1] @@ -2311,6 +2316,7 @@ def create_convert_map(self): "aten::hardswish": self.hard_swish, "aten::cumsum": self.cumsum, "aten::masked_fill": self.masked_fill, + "aten::masked_select": self.masked_select, "aten::argsort": self.argsort, "aten::sort": self.sort, } diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index a262514aa7e5..9f27ea4750af 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3534,176 +3534,187 @@ def test_fn(x, y): verify_model(test_fn, [a, b]) +def test_masked_select(): + def test_fn(x, mask): + return torch.masked_select(x, mask) + + for shape in [(10,), (3, 4), (16, 32, 64)]: + x = torch.randn(*shape) + mask = x.ge(0.5) + verify_trace_model(test_fn, [x, mask], ["llvm"]) + + if __name__ == "__main__": - # # some structural tests - # test_forward_traced_function() - # test_forward_dtypes() - # test_weight_names() - # test_duplicate_weight_use() - - # # Single operator tests - # test_forward_pixel_shuffle() - # test_forward_add() - # test_forward_subtract() - # test_forward_multiply() - # test_forward_matmul() - # test_forward_rsub() - # test_forward_onehot() - # test_forward_embedding() - # test_forward_reshape() - # test_forward_reciprocal() - # test_forward_repeat() - # test_forward_repeat_interleave() - # test_forward_squeeze() - # test_forward_unsqueeze() - # test_forward_concatenate() - # test_forward_reduce_sum() - # test_forward_reduce_prod() - # test_forward_argmin() - # test_forward_argmax() - # test_forward_norm() - # test_forward_frobenius_norm() - # test_forward_std() - # test_forward_variance() - # test_forward_relu() - # test_forward_prelu() - # test_forward_leakyrelu() - # test_forward_elu() - # test_forward_celu() - # test_forward_gelu() - # test_forward_selu() - # test_forward_log_sigmoid() - # test_forward_adaptiveavgpool() - # test_forward_maxpool2d() - # test_forward_maxpool1d() - # test_forward_maxpool3d() - # test_forward_hardtanh() - # test_forward_conv() - # test_forward_conv_transpose() - # test_forward_threshold() - # test_forward_contiguous() - # test_forward_batchnorm() - # test_forward_instancenorm() - # test_forward_layernorm() - # test_forward_groupnorm() - # test_forward_transpose() - # test_forward_size() - # test_forward_view() - # test_forward_select() - # test_forward_take() - # test_forward_topk() - # test_forward_where() - # test_forward_addcdiv() - # test_forward_addcmul() - # test_forward_true_divide() - # test_forward_is_floating_point() - # test_forward_clone() - # test_forward_softplus() - # test_forward_softsign() - # test_forward_logsoftmax() - # test_forward_sigmoid() - # test_forward_dense() - # test_forward_avgpool() - # test_forward_avgpool3d() - # test_forward_dropout() - # test_forward_slice() - # test_forward_mean() - # test_forward_expand() - # test_forward_pow() - # test_forward_unary() - # test_forward_clamp() - # test_forward_clamp_() - # test_forward_logical_not() - # test_forward_bitwise_not() - # test_forward_bitwise_xor() - # test_forward_logical_xor() - # test_forward_isfinite() - # test_forward_isnan() - # test_forward_isinf() - # test_forward_ones() - # test_forward_ones_like() - # test_forward_zeros() - # test_forward_zeros_like() - # test_forward_full() - # test_forward_full_like() - # test_forward_linspace() - # test_forward_arange() - # test_forward_mesh_grid() - # test_forward_chunk() - # test_forward_split() - # test_forward_gather() - # test_upsample() - # test_forward_upsample3d() - # test_forward_nms() - # test_forward_roi_align() - # test_to() - # test_flatten() - # test_type_as() - # test_forward_functional_pad() - # test_forward_zero_pad2d() - # test_forward_constant_pad1d() - # test_forward_constant_pad2d() - # test_forward_constant_pad3d() - # test_forward_reflection_pad1d() - # test_forward_reflection_pad2d() - # test_forward_replication_pad1d() - # test_forward_replication_pad2d() - # test_forward_replication_pad3d() - # test_adaptive_pool3d() - # test_conv3d() - # test_conv3d_transpose() - # test_forward_index() - # test_min_max() - # test_logsumexp() - # test_stack() - # test_stack_dynamic() - # test_forward_unbind() - # test_forward_nonzero() - # test_forward_scatter() - # test_numel() - # test_bincount() - # test_cumsum() - # test_masked_fill() - # test_transformer() - # test_sort() - # test_argsort() + # some structural tests + test_forward_traced_function() + test_forward_dtypes() + test_weight_names() + test_duplicate_weight_use() + + # Single operator tests + test_forward_pixel_shuffle() + test_forward_add() + test_forward_subtract() + test_forward_multiply() + test_forward_matmul() + test_forward_rsub() + test_forward_onehot() + test_forward_embedding() + test_forward_reshape() + test_forward_reciprocal() + test_forward_repeat() + test_forward_repeat_interleave() + test_forward_squeeze() + test_forward_unsqueeze() + test_forward_concatenate() + test_forward_reduce_sum() + test_forward_reduce_prod() + test_forward_argmin() + test_forward_argmax() + test_forward_norm() + test_forward_frobenius_norm() + test_forward_std() + test_forward_variance() + test_forward_relu() + test_forward_prelu() + test_forward_leakyrelu() + test_forward_elu() + test_forward_celu() + test_forward_gelu() + test_forward_selu() + test_forward_log_sigmoid() + test_forward_adaptiveavgpool() + test_forward_maxpool2d() + test_forward_maxpool1d() + test_forward_maxpool3d() + test_forward_hardtanh() + test_forward_conv() + test_forward_conv_transpose() + test_forward_threshold() + test_forward_contiguous() + test_forward_batchnorm() + test_forward_instancenorm() + test_forward_layernorm() + test_forward_groupnorm() + test_forward_transpose() + test_forward_size() + test_forward_view() + test_forward_select() + test_forward_take() + test_forward_topk() + test_forward_where() + test_forward_addcdiv() + test_forward_addcmul() + test_forward_true_divide() + test_forward_is_floating_point() + test_forward_clone() + test_forward_softplus() + test_forward_softsign() + test_forward_logsoftmax() + test_forward_sigmoid() + test_forward_dense() + test_forward_avgpool() + test_forward_avgpool3d() + test_forward_dropout() + test_forward_slice() + test_forward_mean() + test_forward_expand() + test_forward_pow() + test_forward_unary() + test_forward_clamp() + test_forward_clamp_() + test_forward_logical_not() + test_forward_bitwise_not() + test_forward_bitwise_xor() + test_forward_logical_xor() + test_forward_isfinite() + test_forward_isnan() + test_forward_isinf() + test_forward_ones() + test_forward_ones_like() + test_forward_zeros() + test_forward_zeros_like() + test_forward_full() + test_forward_full_like() + test_forward_linspace() + test_forward_arange() + test_forward_mesh_grid() + test_forward_chunk() + test_forward_split() + test_forward_gather() + test_upsample() + test_forward_upsample3d() + test_forward_nms() + test_forward_roi_align() + test_to() + test_flatten() + test_type_as() + test_forward_functional_pad() + test_forward_zero_pad2d() + test_forward_constant_pad1d() + test_forward_constant_pad2d() + test_forward_constant_pad3d() + test_forward_reflection_pad1d() + test_forward_reflection_pad2d() + test_forward_replication_pad1d() + test_forward_replication_pad2d() + test_forward_replication_pad3d() + test_adaptive_pool3d() + test_conv3d() + test_conv3d_transpose() + test_forward_index() + test_min_max() + test_logsumexp() + test_stack() + test_stack_dynamic() + test_forward_unbind() + test_forward_nonzero() + test_forward_scatter() + test_numel() + test_bincount() + test_cumsum() + test_masked_fill() + test_transformer() + test_sort() + test_argsort() test_logical_and() + test_masked_select() - # # Model tests - # test_resnet18() - # test_squeezenet1_0() - # test_squeezenet1_1() - # test_densenet121() - # # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug - # # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 - # # test_inception_v3() - # test_googlenet() - # test_mnasnet0_5() - # test_mobilenet_v2() + # Model tests + test_resnet18() + test_squeezenet1_0() + test_squeezenet1_1() + test_densenet121() + # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug + # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 + # test_inception_v3() + test_googlenet() + test_mnasnet0_5() + test_mobilenet_v2() - # test_custom_conversion_map() + test_custom_conversion_map() - # test_segmentaton_models() - # test_3d_models() + test_segmentaton_models() + test_3d_models() - # # Quantization test - # from qnn_test import test_quantized_imagenet, test_quantized_modules + # Quantization test + from qnn_test import test_quantized_imagenet, test_quantized_modules - # test_quantized_modules() - # test_quantized_imagenet() + test_quantized_modules() + test_quantized_imagenet() - # # Test simple conditionals and loop - # test_control_flow() - # test_simple_rnn() + # Test simple conditionals and loop + test_control_flow() + test_simple_rnn() - # # More complex recurrent models - # from test_lstm import test_custom_lstm + # More complex recurrent models + from test_lstm import test_custom_lstm - # test_custom_lstm() + test_custom_lstm() - # # Test bert model - # test_forward_pretrained_bert_base_uncased() + # Test bert model + test_forward_pretrained_bert_base_uncased() - # # Test convert torch script(jit) with specific inputs' types - # test_convert_torch_script_with_input_types() - # test_hard_swish() + # Test convert torch script(jit) with specific inputs' types + test_convert_torch_script_with_input_types() + test_hard_swish() From ad1903e386bab9929063e656ed1dda24fb59f924 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Jan 2021 20:02:15 +0900 Subject: [PATCH 8/9] add gpu targets to masked_select test --- tests/python/frontend/pytorch/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 9f27ea4750af..6d9b559c6ba1 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3541,7 +3541,7 @@ def test_fn(x, mask): for shape in [(10,), (3, 4), (16, 32, 64)]: x = torch.randn(*shape) mask = x.ge(0.5) - verify_trace_model(test_fn, [x, mask], ["llvm"]) + verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"]) if __name__ == "__main__": From 6b4107de0f8212aed9056abf5efc75f62f326efb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 27 Jan 2021 22:03:01 +0900 Subject: [PATCH 9/9] improve sort conversion --- python/tvm/relay/frontend/pytorch.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 7c033459217a..68e68fdbeed2 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2093,11 +2093,7 @@ def sort(self, inputs, input_types): is_descending = inputs[2] # pytorch sort returns both sorted indices and values indices = _op.argsort(data, dim, not is_descending) - shape = self.infer_shape(data) - if len(shape) == 1: - return _op.take(data, indices, dim), indices - # TOOD(masahi): Is there a better way? - return _op.sort(data, dim, not is_descending), indices + return _op.gather(data, dim, indices), indices def argsort(self, inputs, input_types): data = inputs[0]