83 changes: 74 additions & 9 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -31,6 +31,7 @@
from .. import function as _function
from .. import ty as _ty
from .. import op as _op
from .. import qnn as _qnn
from .common import (
autopad,
fold_constant,
@@ -314,9 +315,9 @@ def convert_conv2d(g, op, block):
strides = op.attr("strides")

kernel = g.get_node(op.input("Filter")[0])
kernel_layout = "OIHW"
input_x = g.get_node(op.input("Input")[0])
data_layout = op.attr("data_format")
kernel_layout = "OIHW" if data_layout == "NCHW" else "HWIO"
out_channels, _, k_h, k_w = infer_shape(kernel)
if padding_algorithm == "VALID":
paddings = [0, 0]
@@ -336,9 +337,15 @@
msg = f'Value {padding_algorithm} in attribute "padding" of operator Conv is not "valid."'
raise tvm.error.OpAttributeInvalid(msg)

if data_layout == "NHWC":
kernel_layout = "HWIO"
# PaddlePaddle wieght layout is "OIHW", tvm need "HWIO" when op data_format is "NHWC".
is_quantized = op.has_attr("quantization_type")
# PaddlePaddle weight layout is "OIHW"; TVM needs "HWIO" when the op data_format is "NHWC".
# There are two situations when converting the data format of the weights:
# 1. conv2d is not a quantized OP: its weights are the kernel parameters themselves,
# so we convert them directly while processing conv2d.
# 2. conv2d is a quantized OP: its weights are the output of the quantize_linear
# operator, so the conversion has to happen while processing quantize_linear.
if (not is_quantized) and (data_layout == "NHWC"):
kernel_data = g.get_params(op.input("Filter")[0])
kernel_data = kernel_data.asnumpy()
kernel_data = kernel_data.transpose((2, 3, 1, 0))
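# the transpose (2, 3, 1, 0) reorders the kernel from (O, I, H, W) to (H, W, I, O),
# matching the "HWIO" kernel_layout selected above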
@@ -1626,7 +1633,7 @@ def convert_pool3d(g, op, block):
raise tvm.error.OpAttributeInvalid(msg.format(padding_algorithm))

# handle with special case
# while kernel size less than input size
# while kernel size is larger than input size
# shrink kernel size to input size
if (
not isinstance(in_h, _op.Expr)
@@ -1812,6 +1819,59 @@ def convert_roi_align(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_dequantize_linear(g, op, block):
"""Operator converter for dequantize_linear."""

data_node_name = op.input("X")[0]
data_node = g.get_node(data_node_name)

# paddle_scale = tvm_scale * 127
paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
tvm_quantize_scale = paddle_quantize_scale / 127.0
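# e.g. a stored scale of 127.0 becomes a TVM scale of 1.0 (assuming the symmetric
# int8 range [-127, 127] implied by the division above)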

tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()

tvm_quantize_axis = op.attr("quant_axis")
if tvm_quantize_axis == -1:
tvm_quantize_axis = 0
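# a quant_axis of -1 denotes per-tensor (per-layer) quantization in PaddlePaddle;
# fall back to axis 0 for TVM's qnn op in that case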

if len(infer_shape(data_node)) < 2:
tvm_quantize_axis = 0
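# inputs of rank 0 or 1 have no meaningful channel axis, so use axis 0 as well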

out = _qnn.op.dequantize(
data=data_node,
input_scale=_op.const(tvm_quantize_scale, "float32"),
input_zero_point=_op.const(tvm_quantize_zp, "int32"),
axis=tvm_quantize_axis,
)
g.add_node(op.output("Y")[0], out)


def convert_quantize_linear(g, op, block):
"""Operator converter for dequantize_linear."""

data_node_name = op.input("X")[0]
data_node = g.get_node(data_node_name)

# paddle_scale = tvm_scale * 127
paddle_quantize_scale = g.get_params(op.input("Scale")[0]).asnumpy()
tvm_quantize_scale = paddle_quantize_scale / 127.0

tvm_quantize_zp = g.get_params(op.input("ZeroPoint")[0]).asnumpy()
tvm_quantize_axis = op.attr("quant_axis")

if tvm_quantize_axis == -1:
tvm_quantize_axis = 0

out = _qnn.op.quantize(
data=data_node,
output_scale=_op.const(tvm_quantize_scale, "float32"),
output_zero_point=_op.const(tvm_quantize_zp, "int32"),
axis=tvm_quantize_axis,
)
g.add_node(op.output("Y")[0], out)


def convert_rnn(g, op, block):
"""Operator converter for rnn."""

@@ -2386,11 +2446,11 @@ def convert_slice(g, op, block):
def convert_softmax(g, op, block):
"""Operator converter for softmax."""

x = g.get_node(op.input("X")[0])
axis = op.attr("axis")
input_shape = block.var(op.input("X")[0]).shape
if axis < 0:
axis = len(input_shape) + axis
x = g.get_node(op.input("X")[0])
m = _op.max(x, axis, keepdims=True)
e = _op.exp(x - m)
out = e / _op.sum(e, axis, keepdims=True)
@@ -2905,6 +2965,9 @@ def convert_where_index(g, op, block):
"unstack": convert_unstack,
"where": convert_where,
"where_index": convert_where_index,
# Quantized
"dequantize_linear": convert_dequantize_linear,
"quantize_linear": convert_quantize_linear,
}


@@ -2938,7 +3001,7 @@ def get_params(self, name=None):

if name is None:
return self.params
assert name in self.params
assert name in self.params, f"The name({name}) is not in params"
return self.params[name]

def extract_parameters(self, program, scope=None):
@@ -2947,9 +3010,12 @@ def extract_parameters(self, program, scope=None):
self.params = {}
variables = program.global_block().vars
for name in variables:
var = program.global_block().var(name)
if name.endswith("feed") or name.endswith("fetch"):
continue
# This check causes the PaddleInference model exported by PaddleSlim
# to skip some operators that need to be read in NHWC format.
var = program.global_block().var(name)
if not var.persistable:
continue
if isinstance(scope, dict):
@@ -3018,7 +3084,6 @@ def from_program(self, program, shape_dict, scope):
for op in block.ops:
if op.type == "fetch":
output_names.append(op.input("X")[0])

outputs = [self.nodes[name] for name in output_names]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
