From 6e7bf536d54587058d16ac86de7af0769731b17c Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Wed, 24 Feb 2021 09:48:54 +0000 Subject: [PATCH 1/3] Convert strides and pool_size to int --- python/tvm/relay/frontend/pytorch.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 205b2aa779e6..1403dd78eed1 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -841,6 +841,16 @@ def maxpool_2d(self, inputs, input_types): dilation = inputs[4] ceil_mode = int(inputs[5]) + if isinstance(pool_size, list): + for i in range(len(pool_size)): + if isinstance(pool_size[i], _expr.Constant): + pool_size[i] = int(_infer_value_simulated(pool_size[i], {}).asnumpy()) + + if isinstance(strides, list): + for i in range(len(strides)): + if isinstance(strides[i], _expr.Constant): + strides[i] = int(_infer_value_simulated(strides[i], {}).asnumpy()) + if dilation != [1, 1]: msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation)) raise NotImplementedError(msg) @@ -1322,6 +1332,16 @@ def avg_pool2d(self, inputs, input_types): ceil_mode = int(inputs[4]) count_include_pad = int(inputs[5]) + if isinstance(pool_size, list): + for i in range(len(pool_size)): + if isinstance(pool_size[i], _expr.Constant): + pool_size[i] = int(_infer_value_simulated(pool_size[i], {}).asnumpy()) + + if isinstance(strides, list): + for i in range(len(strides)): + if isinstance(strides[i], _expr.Constant): + strides[i] = int(_infer_value_simulated(strides[i], {}).asnumpy()) + def func(x): return _op.nn.avg_pool2d( x, From b881f70121cc416fdf2eb97072b62ead634d8f43 Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Thu, 25 Feb 2021 12:17:52 -0800 Subject: [PATCH 2/3] Make helper function, add test --- python/tvm/relay/frontend/pytorch.py | 36 +++++++------------ tests/python/frontend/pytorch/test_forward.py | 10 +++++- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 1403dd78eed1..de2d9502ff7a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -832,25 +832,23 @@ def adaptive_avg_pool_3d(self, inputs, input_types): output_size = inputs[1] return _op.nn.adaptive_avg_pool3d(data, output_size=output_size) + @staticmethod + def convert_const_list(data): + if isinstance(data, list): + for i,_ in enumerate(data): + if isinstance(data[i], _expr.Expr): + data[i] = int(_infer_value_simulated(data[i], {}).asnumpy()) + return data + def maxpool_2d(self, inputs, input_types): data = inputs[0] - pool_size = inputs[1] - strides = inputs[2] if inputs[2] else pool_size + pool_size = self.convert_const_list(inputs[1]) + strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size) padding = inputs[3] dilation = inputs[4] ceil_mode = int(inputs[5]) - if isinstance(pool_size, list): - for i in range(len(pool_size)): - if isinstance(pool_size[i], _expr.Constant): - pool_size[i] = int(_infer_value_simulated(pool_size[i], {}).asnumpy()) - - if isinstance(strides, list): - for i in range(len(strides)): - if isinstance(strides[i], _expr.Constant): - strides[i] = int(_infer_value_simulated(strides[i], {}).asnumpy()) - if dilation != [1, 1]: msg = "MaxPool2d with dilation %s is not implemented" % (str(dilation)) raise NotImplementedError(msg) @@ -1326,22 +1324,12 @@ def softplus(self, inputs, input_types): def avg_pool2d(self, inputs, input_types): data = inputs[0] - pool_size = inputs[1] - strides = inputs[2] if inputs[2] else pool_size + pool_size = self.convert_const_list(inputs[1]) + strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size) padding = inputs[3] ceil_mode = int(inputs[4]) count_include_pad = int(inputs[5]) - if isinstance(pool_size, list): - for i in range(len(pool_size)): - if isinstance(pool_size[i], _expr.Constant): - pool_size[i] = int(_infer_value_simulated(pool_size[i], {}).asnumpy()) - - if isinstance(strides, list): - for i in range(len(strides)): - if isinstance(strides[i], _expr.Constant): - strides[i] = int(_infer_value_simulated(strides[i], {}).asnumpy()) - def func(x): return _op.nn.avg_pool2d( x, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index aa42b0fb84e4..26f7c13d9742 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -736,8 +736,16 @@ def forward(self, *args): output, indices = self.pool(args[0]) return output - verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data) + class MaxPool2DWithIntStrides(Module): + def forward(self, *args): + # Makes kernel_size and strides a Relay expr to test converting back to int + x_shape = args[0].shape + kernel_size = [torch.tensor(x_shape[1]).int(), torch.tensor(x_shape[1]).int()] + strides = [torch.tensor(x_shape[0]).int(), torch.tensor(x_shape[0]).int()] + return torch.nn.functional.max_pool2d(args[0], kernel_size=[4, 4], stride=strides) + verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data) + verify_model(MaxPool2DWithIntStrides().float().eval(), input_data=input_data) @tvm.testing.uses_gpu def test_forward_maxpool1d(): From de0a205211b378e1884e9dca257d8c91697f53c7 Mon Sep 17 00:00:00 2001 From: alexwong <11878166+alexwong@users.noreply.github.com> Date: Thu, 25 Feb 2021 15:37:17 -0800 Subject: [PATCH 3/3] Fix lint --- python/tvm/relay/frontend/pytorch.py | 2 +- tests/python/frontend/pytorch/test_forward.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index de2d9502ff7a..26955d0a3d34 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -835,7 +835,7 @@ def adaptive_avg_pool_3d(self, inputs, input_types): @staticmethod def convert_const_list(data): if isinstance(data, list): - for i,_ in enumerate(data): + for i, _ in enumerate(data): if isinstance(data[i], _expr.Expr): data[i] = int(_infer_value_simulated(data[i], {}).asnumpy()) return data diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 26f7c13d9742..9846e6a53a8f 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -747,6 +747,7 @@ def forward(self, *args): verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data) verify_model(MaxPool2DWithIntStrides().float().eval(), input_data=input_data) + @tvm.testing.uses_gpu def test_forward_maxpool1d(): torch.set_grad_enabled(False)