215 changes: 202 additions & 13 deletions python/tvm/relay/frontend/onnx.py
@@ -209,6 +209,86 @@ def _dim_check(attrs):
return _dim_check, "Only 1d, 2d and 3d kernel supported."


def collect_quant_params(graph, quant_ops):
"""Collect quant params for the quant ops that between
QuantizeLinear and DequantizeLinear nodes
"""
    # initialize the params from graph initializers and Constant nodes
params = {}
for init in graph.initializer:
params[init.name] = get_numpy(init)
for node in graph.node:
if node.op_type == "Constant":
for attr in node.attribute:
if attr.name == "value" and attr.HasField("t"):
params[node.output[0]] = get_numpy(attr.t)

    # map each output name to its producer node and each input name to its
    # consumer nodes, in preparation for the DFS below
output_to_nodes = {}
input_to_nodes = {}
for node in graph.node:
for k in node.output:
output_to_nodes[k] = node
for idx, k in enumerate(node.input):
if k not in input_to_nodes:
input_to_nodes[k] = []
input_to_nodes[k].append([idx, node])

quant_params = {}
    # propagate quant params forward from a QuantizeLinear node to the inputs
    # of quant ops, stopping at DequantizeLinear
def forward_dfs(node, input_idx, scale, zero_point):
if node.op_type == "DequantizeLinear":
return
if node.op_type in quant_ops:
quant_params[node.input[input_idx]] = (scale, zero_point)
return
for out in node.output:
if out in input_to_nodes:
for next_input_idx, next_node in input_to_nodes[out]:
forward_dfs(next_node, next_input_idx, scale, zero_point)

    # propagate quant params backward from a DequantizeLinear node to the
    # outputs of quant ops, stopping at QuantizeLinear
def backward_dfs(node, scale, zero_point):
if node.op_type == "QuantizeLinear":
return
if node.op_type in quant_ops:
for out in node.output:
quant_params[out] = (scale, zero_point)
return
for inp in node.input:
if inp in output_to_nodes:
previous_node = output_to_nodes[inp]
backward_dfs(previous_node, scale, zero_point)

for node in graph.node:
# pass quant params forward to the input of quant ops before DequantizeLinear
if node.op_type == "QuantizeLinear":
            if len(node.input) > 2 and node.input[1] in params and node.input[2] in params:
forward_dfs(node, 0, params[node.input[1]], params[node.input[2]])

# pass quant params backward to the output of quant ops after QuantizeLinear
if node.op_type == "DequantizeLinear":
            if len(node.input) > 2 and node.input[1] in params and node.input[2] in params:
backward_dfs(node, params[node.input[1]], params[node.input[2]])

return quant_params
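

# Illustrative usage sketch (documentation only, not part of the frontend API;
# the model path is hypothetical):
def _example_collect_quant_params(model_path="quantized_model.onnx"):
    import onnx

    model = onnx.load(model_path)
    # Maps tensor names to (scale, zero_point) numpy values for the inputs and
    # outputs of the listed ops that lie between Quantize/DequantizeLinear.
    return collect_quant_params(model.graph, ["Conv", "Gemm", "Add", "Mul", "Sub"])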


def is_quantized_expr(x):
ttype = infer_type(x).checked_type
return ttype.dtype in ["int8", "uint8"]


def check_quant_params(attr, input_ids, output_ids):
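    """Return True only if every listed input and output index has quant params recorded."""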
for input_id in input_ids:
if attr["tvm_custom"]["input_quant_params"][input_id] is None:
return False
    for output_id in output_ids:
        if attr["tvm_custom"]["output_quant_params"][output_id] is None:
            return False
return True
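

# Illustrative sketch (documentation only) of the attr layout the check above
# expects; it is populated by GraphProto.from_onnx further below.
def _example_check_quant_params():
    attr = {
        "tvm_custom": {
            "input_quant_params": [(0.02, 0), (0.05, 128)],  # one entry per input
            "output_quant_params": [(0.1, 0)],  # one entry per output
        }
    }
    assert check_quant_params(attr, [0, 1], [0])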


class OnnxOpConverter(object):
"""A helper class for holding onnx op converters."""

@@ -264,6 +344,36 @@ def _impl_v1(cls, inputs, attr, params):
# TODO(zhreshold): remove hard coded infershape
axis = int(attr.get("axis", 0))
inputs[1] = _op.expand_dims(inputs[1], axis=axis, num_newaxis=2)

if (
is_quantized_expr(inputs[0])
and is_quantized_expr(inputs[1])
and check_quant_params(attr, [0, 1], [0])
):

lhs_scale, lhs_zero_point = attr["tvm_custom"]["input_quant_params"][0]
rhs_scale, rhs_zero_point = attr["tvm_custom"]["input_quant_params"][1]
out_scale, out_zero_point = attr["tvm_custom"]["output_quant_params"][0]

args = [
inputs[0],
inputs[1],
_expr.const(lhs_scale),
_expr.const(int(lhs_zero_point)),
_expr.const(rhs_scale),
_expr.const(int(rhs_zero_point)),
_expr.const(out_scale),
_expr.const(int(out_zero_point)),
]

if op_name == "add":
return _qnn.op.add(*args)
if op_name == "multiply":
return _qnn.op.mul(*args)
if op_name == "subtract":
return _qnn.op.subtract(*args)
            raise tvm.error.OpNotImplemented(
                "Operator {} is not supported with quantized inputs.".format(op_name)
            )

return get_relay_op(op_name)(*inputs)
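
        # Numeric sketch (documentation only): the qnn elementwise ops above
        # map both operands to real values and requantize the combined result,
        # roughly: out_q = round((lhs_s * (lhs_q - lhs_zp) <op> rhs_s *
        # (rhs_q - rhs_zp)) / out_s) + out_zp, saturated to the output dtype.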


@@ -406,6 +516,7 @@ class Conv(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):

# Use shape of input to determine convolution type.
data = inputs[0]
input_shape = infer_shape(data)
@@ -436,16 +547,51 @@ def _impl_v1(cls, inputs, attr, params):
msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr["auto_pad"]))
attr.pop("auto_pad")
out = AttrCvt(
op_name=dimension_picker("conv"),
transforms={
"kernel_shape": "kernel_size",
"dilations": ("dilation", 1),
"pads": ("padding", 0),
"group": ("groups", 1),
},
custom_check=dimension_constraint(),
)([data, inputs[1]], attr, params)

if (
ndim == 4
and is_quantized_expr(data)
and is_quantized_expr(inputs[1])
and check_quant_params(attr, [0, 1], [0])
):

input_scale, input_zero_point = attr["tvm_custom"]["input_quant_params"][0]
kernel_scale, kernel_zero_point = attr["tvm_custom"]["input_quant_params"][1]

weight_shape = infer_shape(inputs[1])
kernel_size = (weight_shape[2], weight_shape[3])
out_channels = weight_shape[0]

padding = attr["pads"] if "pads" in attr else 0
dilation = attr["dilations"] if "dilations" in attr else 1
groups = attr["group"] if "group" in attr else 1
strides = attr["strides"]

out = _qnn.op.conv2d(
data,
inputs[1],
_expr.const(int(input_zero_point)),
_expr.const(int(kernel_zero_point)),
_expr.const(input_scale),
_expr.const(kernel_scale),
kernel_size=kernel_size,
dilation=dilation,
strides=strides,
padding=padding,
groups=groups,
channels=out_channels,
)
else:
out = AttrCvt(
op_name=dimension_picker("conv"),
transforms={
"kernel_shape": "kernel_size",
"dilations": ("dilation", 1),
"pads": ("padding", 0),
"group": ("groups", 1),
},
custom_check=dimension_constraint(),
)([data, inputs[1]], attr, params)
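
        # Note (documentation sketch): in the quantized branch, qnn.conv2d
        # produces an int32 accumulator whose effective scale is
        # input_scale * kernel_scale; a later requantize/dequantize step is
        # expected to map it back to the output quantization domain.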

use_bias = len(inputs) == 3
if use_bias:
@@ -587,9 +733,43 @@ def _impl_v1(cls, inputs, attr, params):
inputs[0] = _op.nn.batch_flatten(inputs[0])
if alpha != 1.0:
inputs[0] *= _expr.const(alpha)
out = _op.nn.dense(inputs[0], inputs[1], units=channels)
if len(inputs) == 3:
out = out + _expr.const(beta) * inputs[2]

if (
is_quantized_expr(inputs[0])
and is_quantized_expr(inputs[1])
and check_quant_params(attr, [0, 1], [0])
):

input_scale, input_zero_point = attr["tvm_custom"]["input_quant_params"][0]
kernel_scale, kernel_zero_point = attr["tvm_custom"]["input_quant_params"][1]

out = _qnn.op.dense(
inputs[0],
inputs[1],
_expr.const(int(input_zero_point)),
_expr.const(int(kernel_zero_point)),
_expr.const(input_scale),
_expr.const(kernel_scale),
units=channels,
)
if len(inputs) == 3:
output_scale, output_zero_point = attr["tvm_custom"]["output_quant_params"][0]
bias = _qnn.op.dequantize(
                    inputs[2],
                    _expr.const(output_scale),
                    _expr.const(int(output_zero_point)),
                    axis=0,
)
bias = _expr.const(beta) * bias
bias = _qnn.op.quantize(
bias,
_expr.const(output_scale),
_expr.const(0, "int32"),
out_dtype="int32",
axis=0,
)
out = out + bias
else:
out = _op.nn.dense(inputs[0], inputs[1], units=channels)
if len(inputs) == 3:
out = out + _expr.const(beta) * inputs[2]
return out
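
        # Numeric sketch (documentation only) of the quantized bias path above:
        # the integer bias is mapped to real values with the output scale,
        # scaled by beta, then re-quantized to int32 with zero point 0 so it
        # can be added to the int32 accumulator produced by qnn.dense.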


@@ -2937,6 +3117,7 @@ def __init__(self, shape, dtype, freeze_params=False):
self._dtype = dtype
self.opset = None
self._freeze_params = freeze_params
self._quant_params = {}

def __enter__(self):
self._old_manager = GraphProto.current
@@ -3054,6 +3235,7 @@ def from_onnx(self, graph, opset, get_output_expr=False):
msg = "The following operators are not supported for frontend ONNX: "
msg += ", ".join(unsupported_ops)
raise tvm.error.OpNotImplemented(msg)
self._quant_params = collect_quant_params(graph, ["Conv", "Gemm", "Add", "Mul", "Sub"])
# construct nodes, nodes are stored as directed acyclic graph
for node in graph.node:
op_name = node.op_type
@@ -3067,9 +3249,16 @@ def _impl_v1(cls, inputs, attr, params):
inputs[i] = None
i_name = self._parse_value_proto(node)
node_output = self._fix_outputs(op_name, node.output)

attr["tvm_custom"] = {}
attr["tvm_custom"]["name"] = i_name
attr["tvm_custom"]["num_outputs"] = len(node_output)
attr["tvm_custom"]["input_quant_params"] = [
self._quant_params[x] if x in self._quant_params else None for x in node.input
]
attr["tvm_custom"]["output_quant_params"] = [
self._quant_params[x] if x in self._quant_params else None for x in node.output
]
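            # Documentation sketch: for a node with inputs [x, w] and output y,
            # this yields, e.g.,
            #   input_quant_params  == [(x_scale, x_zp), (w_scale, w_zp)]
            #   output_quant_params == [(y_scale, y_zp)]
            # with None for any tensor that has no recorded quant params.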

op = self._convert_operator(op_name, inputs, attr, opset)
if not isinstance(op, _expr.TupleWrapper):