From c2bf9fbd868520a39ba2b2657e559eade88e1e93 Mon Sep 17 00:00:00 2001 From: Zheng-Bicheng Date: Wed, 21 Feb 2024 20:20:50 +0800 Subject: [PATCH 1/2] support conv2d when data_format is NHWC --- python/tvm/relay/frontend/paddlepaddle.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 1a3b119b383f..c64eb91e6ee4 100755 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -314,6 +314,7 @@ def convert_conv2d(g, op, block): strides = op.attr("strides") kernel = g.get_node(op.input("Filter")[0]) + kernel_layout = "OIHW" input_x = g.get_node(op.input("Input")[0]) data_layout = op.attr("data_format") out_channels, _, k_h, k_w = infer_shape(kernel) @@ -335,6 +336,16 @@ def convert_conv2d(g, op, block): msg = f'Value {padding_algorithm} in attribute "padding" of operator Conv is not "valid."' raise tvm.error.OpAttributeInvalid(msg) + if data_layout == "NHWC": + kernel_layout = "HWIO" + # PaddlePaddle wieght layout is "OIHW", tvm need "HWIO" when op data_format is "NHWC" + kernel_data = g.get_params(op.input("Filter")[0]) + kernel_data = kernel_data.asnumpy() + kernel_data = kernel_data.transpose((2, 3, 1, 0)) + kernel_data = _nd.array(kernel_data) + g.modify_node(op.input("Filter")[0], kernel_data) + kernel = g.get_node(op.input("Filter")[0]) + out = _op.nn.conv2d( input_x, kernel, @@ -345,6 +356,7 @@ def convert_conv2d(g, op, block): channels=out_channels, kernel_size=[k_h, k_w], data_layout=data_layout, + kernel_layout=kernel_layout, ) g.add_node(op.output("Output")[0], out) @@ -2915,6 +2927,12 @@ def add_node(self, name, node): self.nodes[name] = fold_constant(node) + def modify_node(self, name, params): + """modify node from graph""" + + self.params[name] = params + self.nodes[name] = new_var(name, shape=params.shape, dtype=params.dtype) + def get_params(self, name=None): """Get params from graph.""" From f6223d1e6a774f44576c5e7187220a2751838ba6 Mon Sep 17 00:00:00 2001 From: Zheng-Bicheng <58363586+Zheng-Bicheng@users.noreply.github.com> Date: Wed, 21 Feb 2024 22:39:17 +0800 Subject: [PATCH 2/2] modify the annotation --- python/tvm/relay/frontend/paddlepaddle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index c64eb91e6ee4..bb72d30352af 100755 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -338,7 +338,7 @@ def convert_conv2d(g, op, block): if data_layout == "NHWC": kernel_layout = "HWIO" - # PaddlePaddle wieght layout is "OIHW", tvm need "HWIO" when op data_format is "NHWC" + # PaddlePaddle wieght layout is "OIHW", tvm need "HWIO" when op data_format is "NHWC". kernel_data = g.get_params(op.input("Filter")[0]) kernel_data = kernel_data.asnumpy() kernel_data = kernel_data.transpose((2, 3, 1, 0))