From d8ec4fa9284937ad351bd513e9a5e692c4dfece7 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Tue, 13 May 2025 13:36:27 +0530 Subject: [PATCH 1/4] Bugfix conv_transpose1d and conv_transpose2d --- .../frontend/torch/base_fx_graph_translator.py | 16 ++++++++++++---- python/tvm/relax/frontend/torch/fx_translator.py | 2 ++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index f789eb8af35b..55019d832373 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -733,6 +733,7 @@ def _conv_transpose1d_impl( padding: Optional[Tuple], dilation: Optional[Tuple], groups: Optional[Tuple], + output_padding: Optional[Tuple], ) -> relax.Var: conv1d_transpose = self.block_builder.emit( relax.op.nn.conv1d_transpose( @@ -742,8 +743,9 @@ def _conv_transpose1d_impl( padding=padding, dilation=dilation, groups=groups, + output_padding = output_padding, data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_dtype="float32", ) ) @@ -762,8 +764,9 @@ def _conv_transpose1d(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None stride = args[3] if len(args) > 3 else 1 padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 + output_padding = args[5] if len(args) > 5 else 0 groups = args[6] if len(args) > 6 else 1 + dilation = args[7] if len(args) > 7 else 1 return self._conv_transpose1d_impl( x, weight, @@ -772,6 +775,7 @@ def _conv_transpose1d(self, node: fx.Node) -> relax.Var: padding=padding, dilation=dilation, groups=groups, + output_padding = output_padding ) def _conv_transpose2d_impl( @@ -783,6 +787,7 @@ def _conv_transpose2d_impl( padding: Optional[Tuple], dilation: Optional[Tuple], groups: Optional[Tuple], + output_padding: Optional[Tuple], ) -> relax.Var: conv2d_transpose = self.block_builder.emit( relax.op.nn.conv2d_transpose( @@ -792,8 +797,9 @@ def _conv_transpose2d_impl( padding=padding, dilation=dilation, groups=groups, + output_padding = output_padding, data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_dtype="float32", ) ) @@ -812,8 +818,9 @@ def _conv_transpose2d(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None stride = args[3] if len(args) > 3 else 1 padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 + output_padding = args[5] if len(args) > 5 else 0 groups = args[6] if len(args) > 6 else 1 + dilation = args[7] if len(args) > 7 else 1 return self._conv_transpose2d_impl( x, weight, @@ -822,6 +829,7 @@ def _conv_transpose2d(self, node: fx.Node) -> relax.Var: padding=padding, dilation=dilation, groups=groups, + output_padding=output_padding, ) def _conv1d_impl( diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index fc12f877e012..8db560b0cc3e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -294,6 +294,7 @@ def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: padding=module.padding, dilation=module.dilation, groups=module.groups, + output_padding=module.output_padding ) def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: @@ -310,6 +311,7 @@ def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: padding=module.padding, dilation=module.dilation, groups=module.groups, + output_padding=module.output_padding ) def _conv1d_module(self, node: fx.Node) -> relax.Var: From 49b9353ed61ec80b775f5c855f66da0d82ebd22a Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Tue, 13 May 2025 14:03:09 +0530 Subject: [PATCH 2/4] fix lint issue --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 6 +++--- python/tvm/relax/frontend/torch/fx_translator.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 55019d832373..5c8d7095e511 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -743,7 +743,7 @@ def _conv_transpose1d_impl( padding=padding, dilation=dilation, groups=groups, - output_padding = output_padding, + output_padding=output_padding, data_layout="NCW", kernel_layout="IOW", out_dtype="float32", @@ -775,7 +775,7 @@ def _conv_transpose1d(self, node: fx.Node) -> relax.Var: padding=padding, dilation=dilation, groups=groups, - output_padding = output_padding + output_padding=output_padding, ) def _conv_transpose2d_impl( @@ -797,7 +797,7 @@ def _conv_transpose2d_impl( padding=padding, dilation=dilation, groups=groups, - output_padding = output_padding, + output_padding=output_padding, data_layout="NCHW", kernel_layout="IOHW", out_dtype="float32", diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 8db560b0cc3e..97a2b51e496a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -294,7 +294,7 @@ def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var: padding=module.padding, dilation=module.dilation, groups=module.groups, - output_padding=module.output_padding + output_padding=module.output_padding, ) def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: @@ -311,7 +311,7 @@ def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var: padding=module.padding, dilation=module.dilation, groups=module.groups, - output_padding=module.output_padding + output_padding=module.output_padding, ) def _conv1d_module(self, node: fx.Node) -> relax.Var: From 2126f8670eb89fc59be98bf40eaa42a0cb1e64f2 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Tue, 13 May 2025 16:49:50 +0530 Subject: [PATCH 3/4] Update tests to reflect changes --- .../relax/test_frontend_from_exported_program.py | 12 ++++++++---- tests/python/relax/test_frontend_from_fx.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f01da1336eda..6708cda7c960 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1870,9 +1870,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -1904,9 +1905,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -1962,9 +1964,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0,0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) @@ -1996,9 +1999,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0,0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index f507071b0734..7cb98e2014da 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -168,9 +168,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -202,9 +203,10 @@ def main( w1, strides=[1], padding=[0, 0], + output_padding=[0], dilation=[1], data_layout="NCW", - kernel_layout="OIW", + kernel_layout="IOW", out_layout="NCW", out_dtype="float32", ) @@ -352,9 +354,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0,0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) @@ -386,9 +389,10 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], + output_padding=[0,0], dilation=[1, 1], data_layout="NCHW", - kernel_layout="OIHW", + kernel_layout="IOHW", out_layout="NCHW", out_dtype="float32", ) From da6d04a2ad28984803a21f8586afc5f13b72812a Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Tue, 13 May 2025 17:10:49 +0530 Subject: [PATCH 4/4] lint fix --- tests/python/relax/test_frontend_from_exported_program.py | 4 ++-- tests/python/relax/test_frontend_from_fx.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 6708cda7c960..80da6fcf19ad 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1964,7 +1964,7 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], - output_padding=[0,0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", kernel_layout="IOHW", @@ -1999,7 +1999,7 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], - output_padding=[0,0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", kernel_layout="IOHW", diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 7cb98e2014da..7fb2bed328a8 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -354,7 +354,7 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], - output_padding=[0,0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", kernel_layout="IOHW", @@ -389,7 +389,7 @@ def main( w1, strides=[1, 1], padding=[0, 0, 0, 0], - output_padding=[0,0], + output_padding=[0, 0], dilation=[1, 1], data_layout="NCHW", kernel_layout="IOHW",