diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 8ae1e862ffd5..cb9ea6a043f4 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1353,47 +1353,54 @@ def softplus(self, inputs, input_types): beta = _expr.const(float(inputs[1]), dtype=dtype) return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta - def avg_pool2d(self, inputs, input_types): - data = inputs[0] - - pool_size = self.convert_const_list(inputs[1]) - strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size) - padding = inputs[3] - ceil_mode = int(inputs[4]) - count_include_pad = int(inputs[5]) - - def func(x): - return _op.nn.avg_pool2d( - x, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - ) + def make_avg_pool(self, dim): + def avg_pool(inputs, input_types): + data = inputs[0] - if self.is_quantized_tensor(data): - return qnn_torch.apply_with_upcast(data, func) + pool_size = self.convert_const_list(inputs[1]) + strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size) + padding = inputs[3] + ceil_mode = int(inputs[4]) + count_include_pad = int(inputs[5]) - return func(data) + def func(x): + if dim == 1: + return _op.nn.avg_pool1d( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + elif dim == 2: + return _op.nn.avg_pool2d( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + elif dim == 3: + return _op.nn.avg_pool3d( + x, + pool_size=pool_size, + strides=strides, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + else: + msg = "Average Pooling dimension should be between 1 and 3" + raise RuntimeError(msg) - def avg_pool3d(self, inputs, input_types): - data = inputs[0] + if self.is_quantized_tensor(data): + return qnn_torch.apply_with_upcast(data, func) - pool_size = inputs[1] - strides = inputs[2] if inputs[2] else pool_size - padding = inputs[3] - ceil_mode = int(inputs[4]) - count_include_pad = int(inputs[5]) + return func(data) - return _op.nn.avg_pool3d( - data, - pool_size=pool_size, - strides=strides, - padding=padding, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - ) + return avg_pool def linear(self, inputs, input_types): # https://pytorch.org/docs/stable/nn.functional.html#linear @@ -2350,8 +2357,9 @@ def create_convert_map(self): "aten::log_softmax": self.log_softmax, "aten::sigmoid": self.sigmoid, "aten::softplus": self.softplus, - "aten::avg_pool2d": self.avg_pool2d, - "aten::avg_pool3d": self.avg_pool3d, + "aten::avg_pool1d": self.make_avg_pool(1), + "aten::avg_pool2d": self.make_avg_pool(2), + "aten::avg_pool3d": self.make_avg_pool(3), "aten::linear": self.linear, "aten::dropout": self.dropout, "aten::dropout_": self.dropout, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d0edfd9c8036..572aa472c540 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -809,7 +809,24 @@ def forward(self, *args): @tvm.testing.uses_gpu -def test_forward_avgpool(): +def test_forward_avgpool1d(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10] + + class AvgPool1D2(Module): + def forward(self, *args): + return torch.nn.functional.avg_pool1d(args[0], kernel_size=[10]) + + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.AvgPool1d(kernel_size=[10]).eval(), input_data=input_data) + verify_model(AvgPool1D2().float().eval(), input_data=input_data) + verify_model( + torch.nn.AvgPool1d(kernel_size=[5], stride=2, padding=2).eval(), input_data=input_data + ) + + +@tvm.testing.uses_gpu +def test_forward_avgpool2d(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -820,6 +837,9 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(torch.nn.AvgPool2d(kernel_size=[10, 10]).eval(), input_data=input_data) verify_model(AvgPool2D2().float().eval(), input_data=input_data) + verify_model( + torch.nn.AvgPool2d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data + ) @tvm.testing.uses_gpu @@ -834,6 +854,9 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(torch.nn.AvgPool3d(kernel_size=[10, 10, 10]).eval(), input_data=input_data) verify_model(AvgPool3D1().float().eval(), input_data=input_data) + verify_model( + torch.nn.AvgPool3d(kernel_size=5, stride=2, padding=2).eval(), input_data=input_data + ) @tvm.testing.uses_gpu @@ -3838,7 +3861,8 @@ def test_fn(is_sorted, return_inverse, return_counts): test_forward_logsoftmax() test_forward_sigmoid() test_forward_dense() - test_forward_avgpool() + test_forward_avgpool1d() + test_forward_avgpool2d() test_forward_avgpool3d() test_forward_dropout() test_forward_slice()