44 changes: 43 additions & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -2973,6 +2973,8 @@ def _impl_v13(cls, inputs, attr, params):
        data, scale, zp = inputs
        out_dtype = infer_type(zp).checked_type.dtype
        axis = attr.get("axis", 1)
        if len(infer_shape(data)) < 2:
            axis = 0
        return _qnn.op.quantize(data, scale, _op.cast(zp, "int32"), axis, out_dtype)
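
The new rank check guards against the ONNX default axis of 1, which is out of range for 0-D and 1-D inputs. A minimal NumPy sketch of the intended behaviour (illustrative only, not TVM code; the uint8 output type and helper name are assumptions):

import numpy as np

def quantize_linear(data, scale, zero_point, axis=1):
    # For rank < 2 data there is no channel dimension, so fall back to axis 0
    # (per-tensor quantization with a scalar scale / zero_point).
    if data.ndim < 2:
        axis = 0
    q = np.round(data / scale) + zero_point
    return np.clip(q, 0, 255).astype("uint8")

x = np.array([0.0, 0.5, 1.0], dtype="float32")
print(quantize_linear(x, scale=np.float32(1 / 255), zero_point=np.uint8(0)))  # [  0 128 255]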


@@ -3033,10 +3035,11 @@ def get_scalar(x, dtype="float32"):
        weight = inputs[3]
        w_scale = get_scalar(inputs[4])
        w_zero_point = get_scalar(inputs[5], "int32")
        y_scale = get_scalar(inputs[6])
        y_scale = fold_constant(get_scalar(inputs[6]))
        y_zero_point = get_scalar(inputs[7], "int32")

        input_shape = infer_shape(data)

        ndim = len(input_shape)
        kernel_type = infer_type(weight)
        kernel_shapes = [get_const_tuple(kernel_type.checked_type.shape)]
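
The functional change in this hunk wraps the output scale in fold_constant so that downstream QNN lowering sees a literal constant rather than an unevaluated expression. A hedged sketch of constant folding at the Relay level (the toy expression below is made up, not the QLinearConv scale):

import tvm
from tvm import relay

# Toy constant expression standing in for a scale computed from other constants.
scale = relay.const(0.5) * relay.const(2.0)
mod = tvm.IRModule.from_expr(relay.Function([], scale))
mod = relay.transform.FoldConstant()(mod)
print(mod["main"].body)  # folded to a single constant, 1.0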
@@ -3116,6 +3119,44 @@ def get_scalar(x, dtype="float32"):
        return out


class QLinearAdd(OnnxOpConverter):
    """Operator converter for QLinearAdd from Microsoft onnxruntime contrib opset."""

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        def get_scalar(x, dtype="float32"):
            if isinstance(x, _expr.Var) and x.name_hint in params:
                return _op.const(params[x.name_hint].numpy(), dtype)
            rank = len(infer_shape(x))
            assert rank <= 1, "QLinearAdd scale and zero_point inputs must be scalars"
            if rank == 1:
                x = _op.squeeze(x, [0])
            return _op.cast(x, dtype)

        a = inputs[0]
        a_scale = get_scalar(inputs[1])
        a_zero_point = get_scalar(inputs[2], "int32")
        b = inputs[3]
        b_scale = get_scalar(inputs[4])
        b_zero_point = get_scalar(inputs[5], "int32")
        c_scale = get_scalar(inputs[6])
        c_zero_point = get_scalar(inputs[7], "int32")

        dtype = infer_type(a).checked_type.dtype

        ## Onnxruntime doesn't actually compute this op in integer arithmetic: it
        ## dequantizes both operands to fp32, adds them, and requantizes afterwards.
        ## https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/mlas/lib/qladd.cpp
        a = _qnn.op.dequantize(inputs[0], a_scale, a_zero_point)
        b = _qnn.op.dequantize(inputs[3], b_scale, b_zero_point)
        out = _op.add(a, b)
        return _qnn.op.quantize(out, c_scale, c_zero_point, out_dtype=dtype)
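
For intuition about what the converter above produces, here is a hedged NumPy sketch of the dequantize -> add -> requantize arithmetic (toy scales and zero points, uint8 tensors assumed):

import numpy as np

def qlinear_add_reference(a, a_scale, a_zp, b, b_scale, b_zp, c_scale, c_zp):
    # Dequantize both operands to float, add, then requantize to uint8,
    # mirroring onnxruntime's float fallback referenced in the comment above.
    a_fp = (a.astype("int32") - a_zp) * a_scale
    b_fp = (b.astype("int32") - b_zp) * b_scale
    c = np.round((a_fp + b_fp) / c_scale) + c_zp
    return np.clip(c, 0, 255).astype("uint8")

a = np.array([10, 20], dtype="uint8")
b = np.array([30, 40], dtype="uint8")
print(qlinear_add_reference(a, 0.1, 0, b, 0.1, 0, 0.2, 0))  # [20 30]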


class BitShift(OnnxOpConverter):
    """Operator converter for BitShift"""

@@ -3343,6 +3384,7 @@ def _get_convert_map(opset):
        "DynamicQuantizeLinear": DynamicQuantizeLinear.get_converter(opset),
        "ReverseSequence": ReverseSequence.get_converter(opset),
        "QLinearConv": QLinearConv.get_converter(opset),
        "QLinearAdd": QLinearAdd.get_converter(opset),
    }
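
With the entry above in place, a quantized graph containing the contrib QLinearAdd op imports like any other ONNX model. A hedged usage sketch (the file name is a placeholder):

import onnx
from tvm import relay

onnx_model = onnx.load("model.quant.onnx")
mod, params = relay.frontend.from_onnx(onnx_model)
print(mod)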


62 changes: 61 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import re

import numpy as np
@@ -120,7 +121,7 @@ def get_tvm_output(
def get_onnxruntime_output(model, inputs):
    import onnxruntime.backend

    rep = onnxruntime.backend.prepare(model, "CPU")
    rep = onnxruntime.backend.prepare(model.SerializeToString(), "CPU")
    if isinstance(inputs, list) and len(inputs) == 1:
        inp = inputs[0]
    else:
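
Serializing the model before handing it to onnxruntime.backend.prepare lets the runtime parse the protobuf bytes itself. A hedged sketch of the equivalent call through the plain InferenceSession API, which also accepts bytes or a path ("model" and "inputs" are assumed from the surrounding function):

import onnxruntime

sess = onnxruntime.InferenceSession(model.SerializeToString())
outputs = sess.run(None, {i.name: v for i, v in zip(sess.get_inputs(), inputs)})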
@@ -149,6 +150,7 @@ def verify_with_ort_with_inputs(
):
    if opset is not None:
        model.opset_import[0].version = opset

    ort_out = get_onnxruntime_output(model, inputs)

    if targets is None:
@@ -4755,6 +4757,64 @@ def repeat(N, D):
    )


def verify_qlinearadd(a_shape, b_shape, c_shape):

    a_array = np.random.random(a_shape).astype("float32")
    b_array = np.random.random(b_shape).astype("float32")

    input_nodes = [
        helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
    ]
    input_names = [
        "a",
        "b",
    ]
    input_values = [a_array, b_array]

    # Build a plain float32 Add graph; onnxruntime's static quantization below
    # rewrites it into the QLinearAdd form that the new converter handles.
    node = helper.make_node("Add", input_names, ["C"])
    graph = helper.make_graph(
        [node],
        "qlinearadd_test",
        inputs=input_nodes,
        outputs=[helper.make_tensor_value_info("C", TensorProto.FLOAT, list(c_shape))],
    )
    model = helper.make_model(graph, producer_name="qlinearadd_test")
    from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType

    class RandomDataReader(CalibrationDataReader):
        def __init__(self, n=10):
            self.data = iter(
                [
                    {
                        "a": np.random.random(a_shape).astype("float32"),
                        "b": np.random.random(b_shape).astype("float32"),
                    }
                    for _ in range(n)
                ]
            )

        def get_next(self):
            return next(self.data, None)

    d = tvm.contrib.utils.tempdir()
    model_fp32 = os.path.join(d.temp_dir, "model.onnx")
    onnx.save_model(model, model_fp32)
    model_quant = os.path.join(d.temp_dir, "model.quant.onnx")
    quantize_static(model_fp32, model_quant, RandomDataReader())
    # opt_level=1 will cause error with qnn lowering
    model = onnx.load(model_quant)
    verify_with_ort_with_inputs(model, input_values, opt_level=2)
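
One way to sanity-check the flow above is to list the node types of the quantized model and confirm the contrib QLinearAdd op was actually emitted; a hedged sketch ("model_quant" is the path produced by quantize_static above, and the exact node sequence depends on the onnxruntime version):

quant = onnx.load(model_quant)
print([n.op_type for n in quant.graph.node])
# e.g. ['QuantizeLinear', 'QuantizeLinear', 'QLinearAdd', 'DequantizeLinear']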


def test_qlinearadd():
    verify_qlinearadd([4, 2], [4, 2], [4, 2])
    verify_qlinearadd([4, 2], [2], [4, 2])
    verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7])


if __name__ == "__main__":
    test_flatten()
    test_reshape()