From 8ca8df6e4f818b8334d55943ce985834b0ae719c Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 1 Sep 2024 17:11:12 +0900 Subject: [PATCH 1/9] add test for functional conv1d --- tests/python/relax/test_frontend_from_fx.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index c6c4f2597260..19e2211bfd57 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -48,6 +48,15 @@ def __init__(self): def forward(self, input): return self.conv(input) + class Conv1D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv1d(input, self.weight, self.bias) + @tvm.script.ir_module class expected1: @R.function @@ -113,6 +122,10 @@ def main( binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} verify_model(model, input_info, binding, expected1) + model = Conv1D1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + model = Conv1D2() binding = {"w1": model.conv.weight.detach().numpy()} verify_model(model, input_info, binding, expected2) From aa31762df4017b7fa3f5d222f24ad5a3735f038e Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 1 Sep 2024 17:12:38 +0900 Subject: [PATCH 2/9] add support for functional conv1d --- .../tvm/relax/frontend/torch/fx_translator.py | 66 +++++++++++++++---- 1 file changed, 53 insertions(+), 13 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 676f63b5c359..613a09ff0ca3 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -740,34 +740,54 @@ def _linear_functional(self, node: fx.node.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _conv1d(self, node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - + def _conv1d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: conv1d = self.block_builder.emit( relax.op.nn.conv1d( x, weight, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, data_layout="NCW", kernel_layout="OIW", out_dtype="float32", ) ) - if module.bias is None: + if bias is None: return conv1d - - bias = self.params[module.bias] assert len(self.shape_of(bias)) == 1 bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d, bias)) + def _conv1d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = None + if module.bias is not None: + bias = self.params[module.bias] + + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + def _conv3d(self, node: fx.node.Node) -> 
relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -826,6 +846,25 @@ def _conv2d_impl( bias = relax.op.reshape(bias, (1, -1, 1, 1)) return self.block_builder.emit(relax.op.add(conv2d, bias)) + def _conv1d_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -1482,6 +1521,7 @@ def create_convert_map(self): "type": self._type, "astype": self._type, "matmul": self._matmul, + "conv1d": self._conv1d_functional, "conv2d": self._conv2d_functional, "linear": self._linear_functional, "addmm": self._addmm, From ba14bd1be12d24cb334522a04f2f845138114d35 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 1 Sep 2024 17:14:03 +0900 Subject: [PATCH 3/9] cleanup conv1d --- .../tvm/relax/frontend/torch/fx_translator.py | 126 +++++++++--------- 1 file changed, 63 insertions(+), 63 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 613a09ff0ca3..de8aa9528b07 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -788,33 +788,52 @@ def _conv1d(self, node: fx.node.Node) -> relax.Var: groups=module.groups, ) - def _conv3d(self, node: fx.node.Node) -> relax.Var: + def _conv1d_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] - conv3d = self.block_builder.emit( - relax.op.nn.conv3d( + conv1d_transpose = self.block_builder.emit( + relax.op.nn.conv1d_transpose( x, weight, strides=module.stride, padding=module.padding, dilation=module.dilation, groups=module.groups, - data_layout="NCDHW", - kernel_layout="OIDHW", + data_layout="NCW", + kernel_layout="OIW", out_dtype="float32", ) ) if module.bias is None: - return conv3d + return conv1d_transpose bias = self.params[module.bias] assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv3d, bias)) + return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) def _conv2d_impl( self, @@ -846,7 +865,25 @@ def _conv2d_impl( bias = relax.op.reshape(bias, (1, -1, 1, 1)) return self.block_builder.emit(relax.op.add(conv2d, bias)) - def _conv1d_functional(self, node: fx.node.Node) -> relax.Var: + def _conv2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = 
self.params[module.weight] + bias = None + if module.bias is not None: + bias = self.params[module.bias] + + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv2d_functional(self, node: fx.node.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -855,7 +892,7 @@ def _conv1d_functional(self, node: fx.node.Node) -> relax.Var: padding = args[4] if len(args) > 4 else 0 dilation = args[5] if len(args) > 5 else 1 groups = args[6] if len(args) > 6 else 1 - return self._conv1d_impl( + return self._conv2d_impl( x, weight, bias=bias, @@ -865,98 +902,61 @@ def _conv1d_functional(self, node: fx.node.Node) -> relax.Var: groups=groups, ) - def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: + def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] - conv1d_transpose = self.block_builder.emit( - relax.op.nn.conv1d_transpose( + conv2d_transpose = self.block_builder.emit( + relax.op.nn.conv2d_transpose( x, weight, strides=module.stride, padding=module.padding, dilation=module.dilation, groups=module.groups, - data_layout="NCW", - kernel_layout="OIW", + data_layout="NCHW", + kernel_layout="OIHW", out_dtype="float32", ) ) if module.bias is None: - return conv1d_transpose + return conv2d_transpose bias = self.params[module.bias] assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) + bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) + return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) - def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var: + def _conv3d(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] - conv2d_transpose = self.block_builder.emit( - relax.op.nn.conv2d_transpose( + conv3d = self.block_builder.emit( + relax.op.nn.conv3d( x, weight, strides=module.stride, padding=module.padding, dilation=module.dilation, groups=module.groups, - data_layout="NCHW", - kernel_layout="OIHW", + data_layout="NCDHW", + kernel_layout="OIDHW", out_dtype="float32", ) ) if module.bias is None: - return conv2d_transpose + return conv3d bias = self.params[module.bias] assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - - return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) - - def _conv2d(self, node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] - - return self._conv2d_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) + bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) - def _conv2d_functional(self, node: fx.node.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv2d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - 
) + return self.block_builder.emit(relax.op.add(conv3d, bias)) def _max_pool2d(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] From 968e41adce4235a4255a3faa3c713b87a176a443 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 1 Sep 2024 17:24:37 +0900 Subject: [PATCH 4/9] add test for functional conv_transpose1d --- tests/python/relax/test_frontend_from_fx.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 19e2211bfd57..4c449c126ff6 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -140,6 +140,15 @@ def __init__(self): def forward(self, input): return self.conv(input) + class ConvTranspose1d1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 6, 3]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv_transpose1d(input, self.weight, self.bias) + @tvm.script.ir_module class expected1: @R.function @@ -205,6 +214,10 @@ def main( binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} verify_model(model, input_info, binding, expected1) + model = ConvTranspose1d1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + model = ConvTranspose1d2() binding = {"w1": model.conv.weight.detach().numpy()} verify_model(model, input_info, binding, expected2) From 2200fed1acead591755f4ec5927d817f90bf6919 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 1 Sep 2024 17:24:48 +0900 Subject: [PATCH 5/9] add support for functional conv_transpose1d --- .../tvm/relax/frontend/torch/fx_translator.py | 65 +++++++++++++++---- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index de8aa9528b07..bbf22e27c8b9 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -807,34 +807,74 @@ def _conv1d_functional(self, node: fx.node.Node) -> relax.Var: groups=groups, ) - def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - + def _conv1d_transpose_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: conv1d_transpose = self.block_builder.emit( relax.op.nn.conv1d_transpose( x, weight, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, data_layout="NCW", kernel_layout="OIW", out_dtype="float32", ) ) - if module.bias is None: + if bias is None: return conv1d_transpose - bias = self.params[module.bias] assert len(self.shape_of(bias)) == 1 bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) + def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = None + if module.bias is not None: + bias = self.params[module.bias] + + return self._conv1d_transpose_impl( + 
x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv1d_transpose_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_transpose_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + def _conv2d_impl( self, x: relax.Expr, @@ -1522,6 +1562,7 @@ def create_convert_map(self): "astype": self._type, "matmul": self._matmul, "conv1d": self._conv1d_functional, + "conv_transpose1d": self._conv1d_transpose_functional, "conv2d": self._conv2d_functional, "linear": self._linear_functional, "addmm": self._addmm, From a5c18afb9a70bc8a65fb76e011d8b82bd30c3ff5 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 1 Sep 2024 17:35:01 +0900 Subject: [PATCH 6/9] add test for functional conv_transpose2d --- tests/python/relax/test_frontend_from_fx.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 4c449c126ff6..577ecc61ca29 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -324,6 +324,15 @@ def __init__(self): def forward(self, input): return self.conv(input) + class ConvTranspose2d1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[3, 3, 7, 7]) + self.bias = torch.randn(size=[3]) + + def forward(self, input): + return torch.nn.functional.conv_transpose2d(input, self.weight, self.bias) + @tvm.script.ir_module class expected1: @R.function @@ -389,6 +398,10 @@ def main( binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} verify_model(model, input_info, binding, expected1) + model = ConvTranspose2d1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + model = ConvTranspose2d2() binding = {"w1": model.conv.weight.detach().numpy()} verify_model(model, input_info, binding, expected2) From 52c9957b559a6caf80b285a3af614fe1517ae1be Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 1 Sep 2024 17:35:32 +0900 Subject: [PATCH 7/9] add support for functional conv_transpose2d --- .../tvm/relax/frontend/torch/fx_translator.py | 65 +++++++++++++++---- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index bbf22e27c8b9..e9dd225d7c67 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -942,34 +942,74 @@ def _conv2d_functional(self, node: fx.node.Node) -> relax.Var: groups=groups, ) - def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - + def _conv2d_transpose_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: conv2d_transpose = self.block_builder.emit( 
relax.op.nn.conv2d_transpose( x, weight, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, data_layout="NCHW", kernel_layout="OIHW", out_dtype="float32", ) ) - if module.bias is None: + if bias is None: return conv2d_transpose - bias = self.params[module.bias] assert len(self.shape_of(bias)) == 1 bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) + def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = None + if module.bias is not None: + bias = self.params[module.bias] + + return self._conv2d_transpose_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv2d_transpose_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv2d_transpose_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + def _conv3d(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -1564,6 +1604,7 @@ def create_convert_map(self): "conv1d": self._conv1d_functional, "conv_transpose1d": self._conv1d_transpose_functional, "conv2d": self._conv2d_functional, + "conv_transpose2d": self._conv2d_transpose_functional, "linear": self._linear_functional, "addmm": self._addmm, "baddbmm": self._baddbmm, From aae526dc50acdc28c81bb1ed1e72dcedef82b618 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 1 Sep 2024 17:40:46 +0900 Subject: [PATCH 8/9] add test for functional conv3d --- tests/python/relax/test_frontend_from_fx.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 577ecc61ca29..e191775a63b2 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -416,6 +416,15 @@ def __init__(self): def forward(self, input): return self.conv(input) + class Conv3D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7, 7, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv3d(input, self.weight, self.bias) + @tvm.script.ir_module class expected1: @R.function @@ -481,6 +490,10 @@ def main( binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} verify_model(model, input_info, binding, expected1) + model = Conv3D1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, input_info, binding, expected1) + model = Conv3D2() binding = {"w1": model.conv.weight.detach().numpy()} verify_model(model, input_info, binding, expected2) From cb5f27103debffffac036dd8feac7ad6e7f7d049 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Sun, 1 Sep 2024 17:40:55 +0900 Subject: [PATCH 9/9] add support for functional conv3d --- .../tvm/relax/frontend/torch/fx_translator.py | 66 
+++++++++++++++---- 1 file changed, 53 insertions(+), 13 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e9dd225d7c67..245bb4cffb57 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1010,34 +1010,73 @@ def _conv2d_transpose_functional(self, node: fx.node.Node) -> relax.Var: groups=groups, ) - def _conv3d(self, node: fx.node.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - + def _conv3d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ): conv3d = self.block_builder.emit( relax.op.nn.conv3d( x, weight, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, data_layout="NCDHW", kernel_layout="OIDHW", out_dtype="float32", ) ) - if module.bias is None: + if bias is None: return conv3d - - bias = self.params[module.bias] assert len(self.shape_of(bias)) == 1 bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv3d, bias)) + def _conv3d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = None + if module.bias is not None: + bias = self.params[module.bias] + + return self._conv3d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv3d_functional(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv3d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + def _max_pool2d(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: @@ -1605,6 +1644,7 @@ def create_convert_map(self): "conv_transpose1d": self._conv1d_transpose_functional, "conv2d": self._conv2d_functional, "conv_transpose2d": self._conv2d_transpose_functional, + "conv3d": self._conv3d_functional, "linear": self._linear_functional, "addmm": self._addmm, "baddbmm": self._baddbmm,
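
The nine patches above follow one pattern per operator: factor the existing module-based converter into a shared _convXd_impl (or _convXd_transpose_impl) helper, add a _convXd_functional wrapper that reads weight/bias/stride/padding/dilation/groups from the FX call_function node with PyTorch's defaults (bias=None, stride=1, padding=0, dilation=1, groups=1), and register the functional name in create_convert_map. The sketch below shows what this enables end to end; it is illustrative only, mirrors the Conv1D1Func test case added in PATCH 1/9 (weight shape [6, 3, 7], input [1, 3, 10], float32), and assumes a TVM build that already contains these patches. torch.nn.functional.conv_transpose1d, conv_transpose2d, and conv3d go through the same flow via their respective handlers.

import torch
from torch import fx

import tvm
from tvm.relax.frontend.torch import from_fx


class FunctionalConv1D(torch.nn.Module):
    """Calls torch.nn.functional.conv1d directly instead of nn.Conv1d."""

    def __init__(self):
        super().__init__()
        # Shapes mirror the Conv1D1Func test case from PATCH 1/9.
        self.weight = torch.randn(6, 3, 7)
        self.bias = torch.randn(6)

    def forward(self, x):
        # Traced as a call_function node with target "conv1d", which now
        # dispatches to _conv1d_functional through the convert map.
        return torch.nn.functional.conv1d(x, self.weight, self.bias)


# Trace with FX and import into Relax; input_info is a list of (shape, dtype)
# pairs, exactly as verify_model does in the test file.
graph_module = fx.symbolic_trace(FunctionalConv1D())
mod = from_fx(graph_module, [([1, 3, 10], "float32")])

# The resulting main function contains relax.op.nn.conv1d followed by the
# reshaped-bias add emitted by the shared _conv1d_impl helper.
mod.show()

Factoring the bias reshape and the NCW/OIW layout handling into _conv1d_impl keeps the nn.Conv1d and functional paths from diverging; the functional wrapper only differs in how it obtains its arguments from the node.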