diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index f789eb8af35b..5c8d7095e511 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -733,6 +733,7 @@ def _conv_transpose1d_impl( padding: Optional[Tuple], dilation: Optional[Tuple], groups: Optional[Tuple], + output_padding: Optional[Tuple], ) -> relax.Var: conv1d_transpose = self.block_builder.emit( relax.op.nn.conv1d_transpose( @@ -742,8 +743,9 @@ def _conv_transpose1d_impl( padding=padding, dilation=dilation, groups=groups, + output_padding=output_padding, data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_dtype="float32", ) ) @@ -762,8 +764,9 @@ def _conv_transpose1d(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None stride = args[3] if len(args) > 3 else 1 padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 + output_padding = args[5] if len(args) > 5 else 0 groups = args[6] if len(args) > 6 else 1 + dilation = args[7] if len(args) > 7 else 1 return self._conv_transpose1d_impl( x, weight, @@ -772,6 +775,7 @@ def _conv_transpose1d(self, node: fx.Node) -> relax.Var: padding=padding, dilation=dilation, groups=groups, + output_padding=output_padding, ) def _conv_transpose2d_impl( @@ -783,6 +787,7 @@ def _conv_transpose2d_impl( padding: Optional[Tuple], dilation: Optional[Tuple], groups: Optional[Tuple], + output_padding: Optional[Tuple], ) -> relax.Var: conv2d_transpose = self.block_builder.emit( relax.op.nn.conv2d_transpose( @@ -792,8 +797,9 @@ def _conv_transpose2d_impl( padding=padding, dilation=dilation, groups=groups, + output_padding=output_padding, data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_dtype="float32", ) ) @@ -812,8 +818,9 @@ def _conv_transpose2d(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None stride = args[3] if len(args) > 3 else 1 padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 + output_padding = args[5] if len(args) > 5 else 0 groups = args[6] if len(args) > 6 else 1 + dilation = args[7] if len(args) > 7 else 1 return self._conv_transpose2d_impl( x, weight, @@ -822,6 +829,7 @@ def _conv_transpose2d(self, node: fx.Node) -> relax.Var: padding=padding, dilation=dilation, groups=groups, + output_padding=output_padding, ) def _conv1d_impl( diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index fc12f877e012..97a2b51e496a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -294,6 +294,7 @@ def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: padding=module.padding, dilation=module.dilation, groups=module.groups, + output_padding=module.output_padding, ) def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: @@ -310,6 +311,7 @@ def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: padding=module.padding, dilation=module.dilation, groups=module.groups, + output_padding=module.output_padding, ) def _conv1d_module(self, node: fx.Node) -> relax.Var: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f01da1336eda..80da6fcf19ad 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1870,9 +1870,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -1904,9 +1905,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -1962,9 +1964,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) @@ -1996,9 +1999,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index f507071b0734..7fb2bed328a8 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -168,9 +168,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -202,9 +203,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -352,9 +354,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) @@ -386,9 +389,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", )