diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 676f63b5c359..245bb4cffb57 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -740,61 +740,140 @@ def _linear_functional(self, node: fx.node.Node) -> relax.Var:
         bias = args[2] if len(args) > 2 else None
         return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))
 
-    def _conv1d(self, node: fx.node.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        module = self.named_modules[node.target]
-        weight = self.params[module.weight]
-
+    def _conv1d_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ) -> relax.Var:
         conv1d = self.block_builder.emit(
             relax.op.nn.conv1d(
                 x,
                 weight,
-                strides=module.stride,
-                padding=module.padding,
-                dilation=module.dilation,
-                groups=module.groups,
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
                 data_layout="NCW",
                 kernel_layout="OIW",
                 out_dtype="float32",
             )
         )
 
-        if module.bias is None:
+        if bias is None:
             return conv1d
-
-        bias = self.params[module.bias]
         assert len(self.shape_of(bias)) == 1
         bias = relax.op.reshape(bias, (1, -1, 1))
-
         return self.block_builder.emit(relax.op.add(conv1d, bias))
 
-    def _conv3d(self, node: fx.node.Node) -> relax.Var:
+    def _conv1d(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
+        bias = None
+        if module.bias is not None:
+            bias = self.params[module.bias]
 
-        conv3d = self.block_builder.emit(
-            relax.op.nn.conv3d(
+        return self._conv1d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=module.stride,
+            padding=module.padding,
+            dilation=module.dilation,
+            groups=module.groups,
+        )
+
+    def _conv1d_functional(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        groups = args[6] if len(args) > 6 else 1
+        return self._conv1d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
+    def _conv1d_transpose_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ) -> relax.Var:
+        conv1d_transpose = self.block_builder.emit(
+            relax.op.nn.conv1d_transpose(
                 x,
                 weight,
-                strides=module.stride,
-                padding=module.padding,
-                dilation=module.dilation,
-                groups=module.groups,
-                data_layout="NCDHW",
-                kernel_layout="OIDHW",
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                data_layout="NCW",
+                kernel_layout="OIW",
                 out_dtype="float32",
             )
         )
 
-        if module.bias is None:
-            return conv3d
+        if bias is None:
+            return conv1d_transpose
 
-        bias = self.params[module.bias]
         assert len(self.shape_of(bias)) == 1
-        bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
+        bias = relax.op.reshape(bias, (1, -1, 1))
+        return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
 
-        return self.block_builder.emit(relax.op.add(conv3d, bias))
+    def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+        bias = None
+        if module.bias is not None:
+            bias = self.params[module.bias]
+
+        return self._conv1d_transpose_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=module.stride,
+            padding=module.padding,
+            dilation=module.dilation,
+            groups=module.groups,
+        )
+
+    def _conv1d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        groups = args[6] if len(args) > 6 else 1
+        return self._conv1d_transpose_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
 
     def _conv2d_impl(
         self,
@@ -826,63 +905,142 @@ def _conv2d_impl(
         bias = relax.op.reshape(bias, (1, -1, 1, 1))
         return self.block_builder.emit(relax.op.add(conv2d, bias))
 
-    def _conv1d_transpose(self, node: fx.node.Node) -> relax.Var:
+    def _conv2d(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
+        bias = None
+        if module.bias is not None:
+            bias = self.params[module.bias]
 
-        conv1d_transpose = self.block_builder.emit(
-            relax.op.nn.conv1d_transpose(
+        return self._conv2d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=module.stride,
+            padding=module.padding,
+            dilation=module.dilation,
+            groups=module.groups,
+        )
+
+    def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        groups = args[6] if len(args) > 6 else 1
+        return self._conv2d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
+    def _conv2d_transpose_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ) -> relax.Var:
+        conv2d_transpose = self.block_builder.emit(
+            relax.op.nn.conv2d_transpose(
                 x,
                 weight,
-                strides=module.stride,
-                padding=module.padding,
-                dilation=module.dilation,
-                groups=module.groups,
-                data_layout="NCW",
-                kernel_layout="OIW",
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                data_layout="NCHW",
+                kernel_layout="OIHW",
                 out_dtype="float32",
             )
         )
 
-        if module.bias is None:
-            return conv1d_transpose
+        if bias is None:
+            return conv2d_transpose
 
-        bias = self.params[module.bias]
         assert len(self.shape_of(bias)) == 1
-        bias = relax.op.reshape(bias, (1, -1, 1))
-
-        return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
+        bias = relax.op.reshape(bias, (1, -1, 1, 1))
+        return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
 
     def _conv2d_transpose(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
+        bias = None
+        if module.bias is not None:
+            bias = self.params[module.bias]
 
-        conv2d_transpose = self.block_builder.emit(
-            relax.op.nn.conv2d_transpose(
+        return self._conv2d_transpose_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=module.stride,
+            padding=module.padding,
+            dilation=module.dilation,
+            groups=module.groups,
+        )
+
+    def _conv2d_transpose_functional(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        groups = args[6] if len(args) > 6 else 1
+        return self._conv2d_transpose_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
+    def _conv3d_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ):
+        conv3d = self.block_builder.emit(
+            relax.op.nn.conv3d(
                 x,
                 weight,
-                strides=module.stride,
-                padding=module.padding,
-                dilation=module.dilation,
-                groups=module.groups,
-                data_layout="NCHW",
-                kernel_layout="OIHW",
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                data_layout="NCDHW",
+                kernel_layout="OIDHW",
                 out_dtype="float32",
             )
         )
 
-        if module.bias is None:
-            return conv2d_transpose
-
-        bias = self.params[module.bias]
+        if bias is None:
+            return conv3d
         assert len(self.shape_of(bias)) == 1
-        bias = relax.op.reshape(bias, (1, -1, 1, 1))
-
-        return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
+        bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
+        return self.block_builder.emit(relax.op.add(conv3d, bias))
 
-    def _conv2d(self, node: fx.node.Node) -> relax.Var:
+    def _conv3d(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
@@ -890,7 +1048,7 @@ def _conv2d(self, node: fx.node.Node) -> relax.Var:
         if module.bias is not None:
             bias = self.params[module.bias]
 
-        return self._conv2d_impl(
+        return self._conv3d_impl(
             x,
             weight,
             bias=bias,
@@ -900,7 +1058,7 @@ def _conv2d(self, node: fx.node.Node) -> relax.Var:
             groups=module.groups,
         )
 
-    def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
+    def _conv3d_functional(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
         weight = args[1]
@@ -909,7 +1067,7 @@ def _conv2d_functional(self, node: fx.node.Node) -> relax.Var:
         padding = args[4] if len(args) > 4 else 0
         dilation = args[5] if len(args) > 5 else 1
         groups = args[6] if len(args) > 6 else 1
-        return self._conv2d_impl(
+        return self._conv3d_impl(
             x,
             weight,
             bias=bias,
@@ -1482,7 +1640,11 @@ def create_convert_map(self):
             "type": self._type,
             "astype": self._type,
             "matmul": self._matmul,
+            "conv1d": self._conv1d_functional,
+            "conv_transpose1d": self._conv1d_transpose_functional,
             "conv2d": self._conv2d_functional,
+            "conv_transpose2d": self._conv2d_transpose_functional,
+            "conv3d": self._conv3d_functional,
             "linear": self._linear_functional,
             "addmm": self._addmm,
             "baddbmm": self._baddbmm,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index c6c4f2597260..e191775a63b2 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -48,6 +48,15 @@ def __init__(self):
         def forward(self, input):
             return self.conv(input)
 
+    class Conv1D1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[6, 3, 7])
+            self.bias = torch.randn(size=[6])
+
+        def forward(self, input):
+            return torch.nn.functional.conv1d(input, self.weight, self.bias)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -113,6 +122,10 @@ def main(
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
     verify_model(model, input_info, binding, expected1)
 
+    model = Conv1D1Func()
+    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
+    verify_model(model, input_info, binding, expected1)
+
     model = Conv1D2()
     binding = {"w1": model.conv.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected2)
@@ -127,6 +140,15 @@ def __init__(self):
         def forward(self, input):
             return self.conv(input)
 
+    class ConvTranspose1d1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[6, 6, 3])
+            self.bias = torch.randn(size=[6])
+
+        def forward(self, input):
+            return torch.nn.functional.conv_transpose1d(input, self.weight, self.bias)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -192,6 +214,10 @@ def main(
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
     verify_model(model, input_info, binding, expected1)
 
+    model = ConvTranspose1d1Func()
+    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
+    verify_model(model, input_info, binding, expected1)
+
     model = ConvTranspose1d2()
     binding = {"w1": model.conv.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected2)
@@ -298,6 +324,15 @@ def __init__(self):
         def forward(self, input):
             return self.conv(input)
 
+    class ConvTranspose2d1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[3, 3, 7, 7])
+            self.bias = torch.randn(size=[3])
+
+        def forward(self, input):
+            return torch.nn.functional.conv_transpose2d(input, self.weight, self.bias)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -363,6 +398,10 @@ def main(
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
     verify_model(model, input_info, binding, expected1)
 
+    model = ConvTranspose2d1Func()
+    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
+    verify_model(model, input_info, binding, expected1)
+
     model = ConvTranspose2d2()
     binding = {"w1": model.conv.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected2)
@@ -377,6 +416,15 @@ def __init__(self):
         def forward(self, input):
             return self.conv(input)
 
+    class Conv3D1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[6, 3, 7, 7, 7])
+            self.bias = torch.randn(size=[6])
+
+        def forward(self, input):
+            return torch.nn.functional.conv3d(input, self.weight, self.bias)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -442,6 +490,10 @@ def main(
     binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
     verify_model(model, input_info, binding, expected1)
 
+    model = Conv3D1Func()
+    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
+    verify_model(model, input_info, binding, expected1)
+
     model = Conv3D2()
     binding = {"w1": model.conv.weight.detach().numpy()}
     verify_model(model, input_info, binding, expected2)
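
For reference, a minimal usage sketch of the new functional conv1d path (not part of the patch): the FunctionalConv1d module below is hypothetical and mirrors the Conv1D1Func test case, while fx.symbolic_trace and tvm.relax.frontend.torch.from_fx are existing APIs.

import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx

class FunctionalConv1d(torch.nn.Module):
    # Hypothetical module that calls torch.nn.functional.conv1d directly,
    # so tracing produces a call_function node handled by the new
    # "conv1d" entry in create_convert_map.
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(6, 3, 7)
        self.bias = torch.randn(6)

    def forward(self, x):
        return torch.nn.functional.conv1d(x, self.weight, self.bias)

graph_module = fx.symbolic_trace(FunctionalConv1d())
# input_info uses the same ([shape], dtype) convention as the tests above.
mod = from_fx(graph_module, [([1, 3, 10], "float32")])
mod.show()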