diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 205b2aa779e6..26955d0a3d34 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -832,11 +832,19 @@ def adaptive_avg_pool_3d(self, inputs, input_types): output_size = inputs[1] return _op.nn.adaptive_avg_pool3d(data, output_size=output_size) + @staticmethod + def convert_const_list(data): + if isinstance(data, list): + for i, _ in enumerate(data): + if isinstance(data[i], _expr.Expr): + data[i] = int(_infer_value_simulated(data[i], {}).asnumpy()) + return data + def maxpool_2d(self, inputs, input_types): data = inputs[0] - pool_size = inputs[1] - strides = inputs[2] if inputs[2] else pool_size + pool_size = self.convert_const_list(inputs[1]) + strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size) padding = inputs[3] dilation = inputs[4] ceil_mode = int(inputs[5]) @@ -1316,8 +1324,8 @@ def softplus(self, inputs, input_types): def avg_pool2d(self, inputs, input_types): data = inputs[0] - pool_size = inputs[1] - strides = inputs[2] if inputs[2] else pool_size + pool_size = self.convert_const_list(inputs[1]) + strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size) padding = inputs[3] ceil_mode = int(inputs[4]) count_include_pad = int(inputs[5]) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index aa42b0fb84e4..9846e6a53a8f 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -736,7 +736,16 @@ def forward(self, *args): output, indices = self.pool(args[0]) return output + class MaxPool2DWithIntStrides(Module): + def forward(self, *args): + # Makes kernel_size and strides a Relay expr to test converting back to int + x_shape = args[0].shape + kernel_size = [torch.tensor(x_shape[1]).int(), torch.tensor(x_shape[1]).int()] + strides = [torch.tensor(x_shape[0]).int(), torch.tensor(x_shape[0]).int()] + return torch.nn.functional.max_pool2d(args[0], kernel_size=[4, 4], stride=strides) + verify_model(MaxPool2DWithIndices().float().eval(), input_data=input_data) + verify_model(MaxPool2DWithIntStrides().float().eval(), input_data=input_data) @tvm.testing.uses_gpu