From 2326147379773668c0e552dd869f97420e8560e5 Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 3 May 2021 18:23:04 -0600 Subject: [PATCH 1/6] Convert a fake quantized or QAT graph into qnn ops --- python/tvm/relay/frontend/onnx.py | 6 + python/tvm/relay/op/__init__.py | 1 + python/tvm/relay/op/op.py | 20 ++ python/tvm/relay/transform/__init__.py | 1 + .../transform/quantize_fake_quantization.py | 177 +++++++++++ python/tvm/relay/transform/transform.py | 27 ++ .../transforms/quantize_fake_quantization.cc | 296 ++++++++++++++++++ .../test_pass_quantize_fake_quantization.py | 280 +++++++++++++++++ 8 files changed, 808 insertions(+) create mode 100644 python/tvm/relay/transform/quantize_fake_quantization.py create mode 100644 src/relay/transforms/quantize_fake_quantization.cc create mode 100644 tests/python/relay/test_pass_quantize_fake_quantization.py diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 13505fd0f738..c8855b2ea2be 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2465,6 +2465,12 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v11(cls, inputs, attr, params): + if len(inputs) == 3 and isinstance(inputs[2], _expr.Constant): + attr["max"] = inputs[2].data.asnumpy().item() + inputs = inputs[0:2] + if len(inputs) >= 2 and isinstance(inputs[1], _expr.Constant): + attr["min"] = inputs[1].data.asnumpy().item() + inputs = inputs[0:1] if "min" in attr and "max" in attr: return Clip.convert_attributes(inputs, attr, params) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 1f267abedc1a..610604604691 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -28,6 +28,7 @@ OpStrategy, debug, register_external_compiler, + register_quantize_fake_quantization, ) from . 
import strategy diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 33cb46d67f34..d7a4d37e4a36 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -436,6 +436,26 @@ def register_external_compiler(op_name, fexternal=None, level=10): return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level) +def register_quantize_fake_quantization(op_name, qfq=None, level=10): + """Register quantize function for an op + + Given an op and Affine Types on it's inputs, this function should return the op + in affine space and the new type of the output + + Parameters + ---------- + op_name : str + The name of the operator + + qfq: function (expr: Expr, map: Map) -> new_expr: Expr + The function for translating the op into affine space + + level : int + The priority level + """ + return tvm.ir.register_op_attr(op_name, "FTVMQuantizeFakeQuantization", qfq, level) + + @tvm._ffi.register_func("relay.op.compiler._lower") def _lower(name, schedule, inputs, outputs): return lower(schedule, list(inputs) + list(outputs), name=name) diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index ca9996aeaaae..ffbbf359c8dc 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -19,3 +19,4 @@ # transformation passes from .transform import * from .recast import recast +from . import quantize_fake_quantization diff --git a/python/tvm/relay/transform/quantize_fake_quantization.py b/python/tvm/relay/transform/quantize_fake_quantization.py new file mode 100644 index 000000000000..8ca578264302 --- /dev/null +++ b/python/tvm/relay/transform/quantize_fake_quantization.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relay functions for rewriting fake quantized ops.""" +import tvm +from tvm import relay +from ..op import register_quantize_fake_quantization + + +def fold_constant(expr): + mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.FoldConstant()(mod) + return mod["main"].body + + +@register_quantize_fake_quantization("qnn.dequantize") +def dequantize_qfq(expr, type_map): + """Remove dequantize op""" + out = expr.args[0] + t = type_map[expr] + return [out, t.scale, t.zero_point, t.dtype] + + +@register_quantize_fake_quantization("qnn.quantize") +def quantize_qfq(expr, type_map): + """Turn a quantize op into requantize or remove it""" + out = expr.args[0] + t = type_map[out] + in_scale = fold_constant(t.scale) + in_zero_point = fold_constant(t.zero_point) + if not ( + tvm.ir.structural_equal(in_scale, expr.args[1]) + and tvm.ir.structural_equal(in_zero_point, expr.args[2]) + and tvm.ir.structural_equal(t.dtype, expr.attrs.out_dtype) + ): + out = relay.qnn.op.requantize( + out, + in_scale, + in_zero_point, + expr.args[1], + expr.args[2], + out_dtype=expr.attrs.out_dtype, + ) + return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype] + + +@register_quantize_fake_quantization("reshape") +def reshape_qfq(expr, type_map): + """Rewrite a reshape op""" + arg = expr.args[0] + t = type_map[arg] + out = relay.op.reshape(arg, **expr.attrs) + return [out, t.scale, t.zero_point, t.dtype] + + 
+@register_quantize_fake_quantization("transpose") +def transpose_qfq(expr, type_map): + """Rewrite a transpose op""" + arg = expr.args[0] + t = type_map[arg] + out = relay.op.transpose(arg, **expr.attrs) + return [out, t.scale, t.zero_point, t.dtype] + + +@register_quantize_fake_quantization("nn.max_pool2d") +def maxpool_qfq(expr, type_map): + """Rewrite a maxpool op""" + arg = expr.args[0] + t = type_map[arg] + out = relay.op.nn.max_pool2d(arg, **expr.attrs) + return [out, t.scale, t.zero_point, t.dtype] + + +@register_quantize_fake_quantization("nn.avg_pool2d") +def avgpool_qfq(expr, type_map): + """Rewrite a avgpool op""" + arg = expr.args[0] + t = type_map[arg] + arg = relay.op.cast(arg, "int32") + out = relay.op.nn.avg_pool2d(arg, **expr.attrs) + out = relay.op.cast(out, t.dtype) + return [out, t.scale, t.zero_point, t.dtype] + + +@register_quantize_fake_quantization("nn.bias_add") +def bias_add_qfq(expr, type_map): + """Rewrite a bias_add op""" + x, b = expr.args + x_t = type_map[x] + b_t = type_map[b] + in_scale = fold_constant(x_t.scale) + in_zero_point = fold_constant(x_t.zero_point) + if not tvm.ir.structural_equal(x_t, b_t): + b = relay.qnn.op.requantize( + b, + b_t.scale, + b_t.zero_point, + in_scale, + in_zero_point, + out_dtype=xt.dtype, + ) + out = relay.op.nn.bias_add(x, b, **expr.attrs) + return [out, x_t.scale, x_t.zero_point, x_t.dtype] + + +@register_quantize_fake_quantization("nn.conv2d") +def conv2d_qfq(expr, type_map): + """Rewrite a conv2d op""" + attrs = {**expr.attrs} + attrs.pop("out_dtype") + x, weight = expr.args + x_t = type_map[x] + w_t = type_map[weight] + conv_scale = fold_constant(x_t.scale * w_t.scale) + conv_zp = relay.const(0) + out = relay.qnn.op.conv2d( + x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs + ) + return [out, conv_scale, conv_zp, out.attrs.out_dtype] + + +@register_quantize_fake_quantization("concatenate") +def concat_qfq(expr, type_map): + """Rewrite a concat op""" + scales = [] + zps = 
[] + for arg in expr.args[0].fields: + t = type_map[arg] + scales.append(t.scale) + zps.append(t.zero_point) + + out_type = type_map[expr] + + out = relay.qnn.op.concatenate( + expr.args[0], + relay.Tuple(scales), + relay.Tuple(zps), + out_type.scale, + out_type.zero_point, + **expr.attrs, + ) + return [out, out_type.scale, out_type.zero_point, out_type.dtype] + + +@register_quantize_fake_quantization("clip") +def clip_qfq(expr, type_map): + """Rewrite a clip op""" + arg = expr.args[0] + t = type_map[arg] + amin = expr.attrs.a_min + amax = expr.attrs.a_max + scale = fold_constant(t.scale) + z_p = fold_constant(t.zero_point) + if isinstance(scale, relay.expr.Constant) and isinstance(z_p, relay.expr.Constant): + scale = scale.data.numpy().item() + z_p = z_p.data.numpy().item() + new_min = int(amin / scale + z_p) + new_max = int(amax / scale + z_p) + out = relay.op.clip(arg, new_min, new_max) + else: + amin = relay.op.round(relay.op.const(amin) / scale + z_p) + amax = relay.op.round(relay.op.const(amax) / scale + z_p) + out = relay.op.minimum(relay.op.maximum(arg, amin), amax) + return [out, t.scale, t.zero_point, t.dtype] diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 20e8bb94c501..5659e61a90fb 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1171,3 +1171,30 @@ def AnnotateSpans(): The regsistered AnnotateSpans pass. """ return _ffi_api.AnnotateSpans() + + +def QuantizeFakeQuantization(): + """ + Find regions of the graph of the form + + x w + | | + dq dq + \ / + op1 + | + op2 + | + q + + where q == qnn.quantize and dq = qnn.dequantize + and rewrite them into integer versions of op1 and op2 + + Rules for rewriting indivdual ops are in quantize_fake_quantization.py + + Returns + ------- + ret : tvm.transform.Pass + The registered SimplifyExpr pass. 
+ """ + return _ffi_api.QuantizeFakeQuantization() diff --git a/src/relay/transforms/quantize_fake_quantization.cc b/src/relay/transforms/quantize_fake_quantization.cc new file mode 100644 index 000000000000..1c6bbb119bff --- /dev/null +++ b/src/relay/transforms/quantize_fake_quantization.cc @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/simplify_expr.cc + * \brief A pass for simplifying the Relay expression. + */ + +#include +#include +#include + +/* Description of QuantizeFakeQuantization + * + * The purpose of this pass is to find regions of the graph that follow + * the general pattern: + * + * x w + * | | + * dq dq + * \ / + * op1 + * | + * op2 + * | + * q + * + * and convert them into subgraphs with actual integer operations on x and w + * + * The pass does this via a multi-pass approach: + * + * The main pass is a MixedModeMutator that traverses the full graph searching for + * quantize operations + * + * The second pass is an ExprVisitor that recursively searches for subgraphs leading to the + * quantize for subtraphs bounded by dequantize operations. 
This pass extracts the affine + * types of the inputs for later processing + * + * The third pass is an ExprMutator the recursively rewrites the subgraphs using packed funcs + * registered with the FTVMQuantizeFakeQuantization attribute. These packed funcs rewrite + * the ops based on the affine types of their inputs and then return the affine types of the + * new rewriten ops to pass that information down the stack during rewrite. + * + * After the second and third passes run, the first pass replaces the quantize with the + * rewritten subgraph and the processing continues + */ + +namespace tvm { +namespace relay { + +/*! + * \brief AffineType representation + * \sa AffineType + */ +class AffineTypeNode : public Object { + public: + /*! \brief The scale of this type */ + Expr scale; + /*! \brief The zero point of this type */ + Expr zero_point; + /*! \brief The data type of this type */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("scale", &scale); + v->Visit("zero_point", &zero_point); + v->Visit("dtype", &dtype); + } + + bool SEqualReduce(const AffineTypeNode* other, SEqualReducer equal) const { + equal->MarkGraphNode(); + return equal(scale, other->scale) && equal(zero_point, other->zero_point) && + equal(dtype, other->dtype); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce->MarkGraphNode(); + hash_reduce(scale); + hash_reduce(zero_point); + hash_reduce(dtype); + } + + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const char* _type_key = "AffineTypeNode"; + TVM_DECLARE_BASE_OBJECT_INFO(AffineTypeNode, Object); +}; + +/*! + * \brief Managed reference to AffineTypes. 
+ * \sa AffineTypeNode + */ +class AffineType : public ObjectRef { + public: + TVM_DLL AffineType(Expr scale, Expr zero_point, DataType dtype) { + ObjectPtr n = make_object(); + n->scale = std::move(scale); + n->zero_point = std::move(zero_point); + n->dtype = std::move(dtype); + data_ = std::move(n); + } + TVM_DEFINE_OBJECT_REF_METHODS(AffineType, ObjectRef, AffineTypeNode); +}; + +TVM_REGISTER_NODE_TYPE(AffineTypeNode); + +using ExprSet = std::unordered_set; +using ExprMap = std::unordered_map; +using AffineTypeMap = Map; + +using FTVMQuantizeFakeQuantization = + runtime::TypedPackedFunc(const Expr& expr, const AffineTypeMap& map)>; + +class SubgraphExtractor : public ExprVisitor { + public: + const ExprSet GetSubgraph(const Expr& expr) { + VisitExpr(expr); + ExprSet subgraph; + if (is_fake_quantized_) { + for (auto kv : this->visit_counter_) { + if (auto call_node = GetRef(kv.first).as()) { + if (call_node->op != quantize_op_) { + subgraph.insert(Downcast(GetRef(kv.first))); + } + } + } + } + return subgraph; + } + const AffineTypeMap GetAffineTypes() { return affine_types_; } + void VisitExpr(const Expr& expr) { + if (expr.as() == nullptr && expr.as() == nullptr && + expr.as() == nullptr) { + is_fake_quantized_ = false; + } else { + ExprVisitor::VisitExpr(expr); + } + } + + protected: + void VisitExpr_(const CallNode* call_node) override { + if (call_node->op == quantize_op_) { + // Only look at arg0 for quantize + VisitExpr(call_node->args[0]); + // Collect type of quantize ops + affine_types_.Set(GetRef(call_node), + AffineType(call_node->args[1], call_node->args[2], + call_node->checked_type().as()->dtype)); + } else if (call_node->op == dequantize_op_) { + // Collect type of dequantize ops + affine_types_.Set(GetRef(call_node), + AffineType(call_node->args[1], call_node->args[2], + call_node->args[0]->checked_type().as()->dtype)); + } else { + // run normally on everything else. 
+ ExprVisitor::VisitExpr_(call_node); + } + } + + const Op quantize_op_ = Op::Get("qnn.quantize"); + const Op dequantize_op_ = Op::Get("qnn.dequantize"); + bool is_fake_quantized_ = true; + AffineTypeMap affine_types_; +}; + +class SubgraphMutator : public ExprMutator { + public: + SubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types) + : subgraph_(subgraph), affine_types_(affine_types) {} + + Expr MutateSubgraph(const Expr& expr) { + if (subgraph_.size() == 0) { + return expr; + } + const CallNode* quantize_node = expr.as(); + ICHECK(quantize_node); + ICHECK(quantize_node->op == quantize_op_); + out_type_ = affine_types_[expr]; + static auto fqfq = Op::GetAttrMap("FTVMQuantizeFakeQuantization"); + for (auto node : subgraph_) { + if (!fqfq.count(Downcast(node.as()->op))) { + // Only modify the subgraph if we have translation + // rules for every op + return expr; + } + } + return Mutate(expr); + } + + protected: + Expr VisitExpr_(const CallNode* call_node) { + Expr out; + + static auto fqfq = Op::GetAttrMap("FTVMQuantizeFakeQuantization"); + Op op = Downcast(call_node->op); + if (fqfq.count(op)) { + Expr expr; + if (op == dequantize_op_) { + expr = GetRef(call_node); + } else { + expr = ExprMutator::VisitExpr_(call_node); + // Set the current op to the output type, useful if we can't deduce output parameters + // from input parameters + affine_types_.Set(expr, out_type_); + } + // Call the rewrite + Array vals = fqfq[op](expr, affine_types_); + // Save teh outputs of the rewrite + ICHECK(vals.size() == 4) + << "got the wrong number of returned arguments from FTWMQuantizeFakeQuantization for " + << AsText(op, false); + out = Downcast(vals[0]); + affine_types_.Set(out, AffineType(Downcast(vals[1]), Downcast(vals[2]), + DataType(String2DLDataType(Downcast(vals[3]))))); + } else { + ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node " + << AsText(GetRef(call_node), false); + } + return out; + } + ExprSet subgraph_; + AffineTypeMap 
affine_types_; + AffineType out_type_; + const Op quantize_op_ = Op::Get("qnn.quantize"); + const Op dequantize_op_ = Op::Get("qnn.dequantize"); +}; + +class FakeQuantizationRewriter : public MixedModeMutator { + protected: + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (const CallNode* call_node = post.as()) { + if (call_node->op == quantize_op_) { + SubgraphExtractor extractor; + ExprSet subgraph = extractor.GetSubgraph(GetRef(pre)); + AffineTypeMap affine_types = extractor.GetAffineTypes(); + + ExprSet post_subgraph; + AffineTypeMap post_affine_types; + + for (auto kv : affine_types) { + if (pre == kv.first.as()) { + // we havent memoized the current op yet + post_affine_types.Set(post, kv.second); + } else { + post_affine_types.Set(memo_.at(kv.first), kv.second); + } + } + for (auto expr : subgraph) { + post_subgraph.insert(memo_[expr]); + } + Expr out = SubgraphMutator(post_subgraph, post_affine_types).MutateSubgraph(post); + return out; + } + } + return post; + } + const Op quantize_op_ = Op::Get("qnn.quantize"); +}; + +Expr QuantizeFakeQuantization(const Expr& expr, const IRModule& mod) { + return FakeQuantizationRewriter().Mutate(expr); +} + +namespace transform { + +Pass QuantizeFakeQuantization() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(QuantizeFakeQuantization(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "QuantizeFakeQuantization", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.QuantizeFakeQuantization") + .set_body_typed(QuantizeFakeQuantization); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_quantize_fake_quantization.py b/tests/python/relay/test_pass_quantize_fake_quantization.py new file mode 100644 index 000000000000..96daaea7bcc7 --- /dev/null +++ b/tests/python/relay/test_pass_quantize_fake_quantization.py @@ -0,0 +1,280 @@ +# Licensed to the Apache Software Foundation 
(ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-wildcard-import +import numpy as np +import pytest + +import tvm +from tvm import relay +from tvm.relay.dataflow_pattern import * + + +def test_fake_quantize_conv(): + for out_dtype in ["int8", "uint8"]: + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.conv2d( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize(w, relay.const(0.5), zero), + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype=out_dtype) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + + mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np, w_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np, w_np).asnumpy() + + 
assert np.array_equal(result, result2) + + +def test_fake_transpose_quantize_conv(): + x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + x = relay.transpose(x, [0, 3, 1, 2]) + op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.qnn.op.quantize(op, one, zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + + mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np, w_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np, w_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_transpose_quantize_conv_bias_add(): + x = relay.var("x", shape=[1, 224, 224, 3], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + bias = relay.var("bias", shape=[16], dtype="int32") + one = relay.const(1.0) + zero = relay.const(0) + + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + x = relay.transpose(x, [0, 3, 1, 2]) + op = relay.op.nn.conv2d(x, relay.qnn.op.dequantize(w, relay.const(0.5), zero)) + op = relay.op.nn.bias_add(op, relay.qnn.op.dequantize(bias, one, zero)) + op = relay.qnn.op.quantize(op, one, zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") + w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") + bias_np = np.random.randint(-32768, 
32767, size=[16], dtype="int32") + + mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np, w_np, bias_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np, w_np, bias_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_quantize_maxpool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.max_pool2d(x, [3, 3]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_quantize_avgpool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.avg_pool2d(x, [3, 3]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex 
= relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np).asnumpy() + + assert np.all(np.abs(result - result2) <= 1) + + +def test_fake_quantize_reshape(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.reshape(x, [1, 3, -1]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_quantize_transpose_reshape(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.transpose(x, [1, 0, 2, 3]) + op = relay.op.reshape(op, [3, -1]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") + + mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), 
target="llvm") + result2 = ex.evaluate()(x_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_quantize_concat(): + zero = relay.const(0) + inputs = [] + for i in range(4): + inputs.append( + relay.qnn.op.dequantize( + relay.var("x%d" % i, shape=[1, 4], dtype="int8"), relay.const(i + 0.5), zero + ) + ) + concat = relay.op.concatenate(inputs, axis=1) + out = relay.qnn.op.quantize(concat, relay.const(3.5), zero) + + mod = tvm.IRModule.from_expr(out) + mod = tvm.relay.transform.InferType()(mod) + + inputs_np = [] + for i in range(4): + inputs_np.append(np.random.randint(-128, 127, size=[1, 4], dtype="int8")) + + mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(*inputs_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(*inputs_np).asnumpy() + + assert np.array_equal(result, result2) + + +def test_fake_quantize_clip(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="uint8") + + x = relay.qnn.op.dequantize(x, relay.const(2.0), relay.const(114)) + op = relay.op.clip(x, 0, 6) + op = relay.qnn.op.quantize(op, relay.const(2.0), relay.const(114), out_dtype="uint8") + + mod = tvm.IRModule.from_expr(op) + mod = tvm.relay.transform.InferType()(mod) + + x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") + + mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + assert not tvm.ir.structural_equal(mod, mod2) + mod2 = tvm.relay.transform.FoldConstant()(mod2) + + ex = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np).asnumpy() + + ex = relay.create_executor("vm", mod=mod2, device=tvm.cpu(), target="llvm") + result2 = ex.evaluate()(x_np).asnumpy() + + assert np.array_equal(result, result2) From 
b7ee25d336d0cbd56ac513d07ee930185c22a8ab Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 25 May 2021 10:56:50 -0600 Subject: [PATCH 2/6] fix pylint --- python/tvm/relay/transform/transform.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 5659e61a90fb..3647dd79cb62 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1174,6 +1174,7 @@ def AnnotateSpans(): def QuantizeFakeQuantization(): + # pylint: disable=anomalous-backslash-in-string """ Find regions of the graph of the form From 02fd4dcab849b4bfa61cd93af8877320fabaf6f2 Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 25 May 2021 14:53:31 -0600 Subject: [PATCH 3/6] fix typos --- src/relay/transforms/quantize_fake_quantization.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/quantize_fake_quantization.cc b/src/relay/transforms/quantize_fake_quantization.cc index 1c6bbb119bff..e4024f7610b5 100644 --- a/src/relay/transforms/quantize_fake_quantization.cc +++ b/src/relay/transforms/quantize_fake_quantization.cc @@ -18,8 +18,9 @@ */ /*! - * \file src/relay/transforms/simplify_expr.cc - * \brief A pass for simplifying the Relay expression. + * \file src/relay/transforms/quantize_fake_quantization.cc + * \brief A pass for taking fake quantized graphs and converting them + * to actual integer operations. */ #include @@ -52,7 +53,7 @@ * quantize for subtraphs bounded by dequantize operations. This pass extracts the affine * types of the inputs for later processing * - * The third pass is an ExprMutator the recursively rewrites the subgraphs using packed funcs + * The third pass is an ExprMutator that recursively rewrites the subgraphs using packed funcs * registered with the FTVMQuantizeFakeQuantization attribute. 
These packed funcs rewrite * the ops based on the affine types of their inputs and then return the affine types of the * new rewriten ops to pass that information down the stack during rewrite. From 5dc1afd50f5db96b86a6b5d64c7e90605400a09f Mon Sep 17 00:00:00 2001 From: Matthew Date: Tue, 1 Jun 2021 10:09:07 -0600 Subject: [PATCH 4/6] use an identity function for some ops --- .../transform/quantize_fake_quantization.py | 32 ++++++------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/transform/quantize_fake_quantization.py b/python/tvm/relay/transform/quantize_fake_quantization.py index 8ca578264302..d5407d353ac3 100644 --- a/python/tvm/relay/transform/quantize_fake_quantization.py +++ b/python/tvm/relay/transform/quantize_fake_quantization.py @@ -57,31 +57,19 @@ def quantize_qfq(expr, type_map): return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype] -@register_quantize_fake_quantization("reshape") -def reshape_qfq(expr, type_map): - """Rewrite a reshape op""" - arg = expr.args[0] - t = type_map[arg] - out = relay.op.reshape(arg, **expr.attrs) - return [out, t.scale, t.zero_point, t.dtype] +def register_qfq_identity(op_name, op): + def identity(expr, type_map): + arg = expr.args[0] + t = type_map[arg] + out = op(arg, **expr.attrs) + return [out, t.scale, t.zero_point, t.dtype] - -@register_quantize_fake_quantization("transpose") -def transpose_qfq(expr, type_map): - """Rewrite a transpose op""" - arg = expr.args[0] - t = type_map[arg] - out = relay.op.transpose(arg, **expr.attrs) - return [out, t.scale, t.zero_point, t.dtype] + return register_quantize_fake_quantization(op_name, identity) -@register_quantize_fake_quantization("nn.max_pool2d") -def maxpool_qfq(expr, type_map): - """Rewrite a maxpool op""" - arg = expr.args[0] - t = type_map[arg] - out = relay.op.nn.max_pool2d(arg, **expr.attrs) - return [out, t.scale, t.zero_point, t.dtype] +register_qfq_identity("reshape", relay.op.reshape) 
+register_qfq_identity("transpose", relay.op.transpose) +register_qfq_identity("nn.max_pool2d", relay.op.nn.max_pool2d) @register_quantize_fake_quantization("nn.avg_pool2d") From 734a0d8f4ff791fdfba99e71309ae0a70ff4c390 Mon Sep 17 00:00:00 2001 From: Matthew Date: Fri, 4 Jun 2021 09:18:15 -0600 Subject: [PATCH 5/6] rename the pass from quantize_fake_quantization to fake_quantization_to_integer --- python/tvm/relay/op/__init__.py | 2 +- python/tvm/relay/op/op.py | 10 ++--- python/tvm/relay/transform/__init__.py | 2 +- ...ion.py => fake_quantization_to_integer.py} | 41 ++++++++++--------- python/tvm/relay/transform/transform.py | 6 +-- ...ion.cc => fake_quantization_to_integer.cc} | 26 ++++++------ ...test_pass_fake_quantization_to_integer.py} | 19 ++++----- 7 files changed, 54 insertions(+), 52 deletions(-) rename python/tvm/relay/transform/{quantize_fake_quantization.py => fake_quantization_to_integer.py} (82%) rename src/relay/transforms/{quantize_fake_quantization.cc => fake_quantization_to_integer.cc} (92%) rename tests/python/relay/{test_pass_quantize_fake_quantization.py => test_pass_fake_quantization_to_integer.py} (94%) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 610604604691..4c693fe64ee0 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -28,7 +28,7 @@ OpStrategy, debug, register_external_compiler, - register_quantize_fake_quantization, + register_fake_quantization_to_integer, ) from . 
import strategy diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index d7a4d37e4a36..9a158b2027f3 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -436,24 +436,24 @@ def register_external_compiler(op_name, fexternal=None, level=10): return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level) -def register_quantize_fake_quantization(op_name, qfq=None, level=10): +def register_fake_quantization_to_integer(op_name, func=None, level=10): """Register quantize function for an op Given an op and Affine Types on it's inputs, this function should return the op - in affine space and the new type of the output + in affine space/integer operators and the new type of the output Parameters ---------- op_name : str The name of the operator - qfq: function (expr: Expr, map: Map) -> new_expr: Expr - The function for translating the op into affine space + func: function (expr: Expr, map: Map) -> new_expr: Expr + The function for translating the op into affine space and integer operators level : int The priority level """ - return tvm.ir.register_op_attr(op_name, "FTVMQuantizeFakeQuantization", qfq, level) + return tvm.ir.register_op_attr(op_name, "FTVMFakeQuantizationToInteger", func, level) @tvm._ffi.register_func("relay.op.compiler._lower") diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index ffbbf359c8dc..9ed40f85c3bc 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -19,4 +19,4 @@ # transformation passes from .transform import * from .recast import recast -from . import quantize_fake_quantization +from . 
import fake_quantization_to_integer diff --git a/python/tvm/relay/transform/quantize_fake_quantization.py b/python/tvm/relay/transform/fake_quantization_to_integer.py similarity index 82% rename from python/tvm/relay/transform/quantize_fake_quantization.py rename to python/tvm/relay/transform/fake_quantization_to_integer.py index d5407d353ac3..5f4c53772eec 100644 --- a/python/tvm/relay/transform/quantize_fake_quantization.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -17,7 +17,7 @@ """Relay functions for rewriting fake quantized ops.""" import tvm from tvm import relay -from ..op import register_quantize_fake_quantization +from ..op import register_fake_quantization_to_integer def fold_constant(expr): @@ -26,16 +26,16 @@ def fold_constant(expr): return mod["main"].body -@register_quantize_fake_quantization("qnn.dequantize") -def dequantize_qfq(expr, type_map): +@register_fake_quantization_to_integer("qnn.dequantize") +def dequantize(expr, type_map): """Remove dequantize op""" out = expr.args[0] t = type_map[expr] return [out, t.scale, t.zero_point, t.dtype] -@register_quantize_fake_quantization("qnn.quantize") -def quantize_qfq(expr, type_map): +@register_fake_quantization_to_integer("qnn.quantize") +def quantize(expr, type_map): """Turn a quantize op into requantize or remove it""" out = expr.args[0] t = type_map[out] @@ -57,23 +57,24 @@ def quantize_qfq(expr, type_map): return [out, expr.args[1], expr.args[2], expr.attrs.out_dtype] -def register_qfq_identity(op_name, op): +def register_unary_identity(op_name, op): def identity(expr, type_map): + assert len(expr.args) == 1 arg = expr.args[0] t = type_map[arg] out = op(arg, **expr.attrs) return [out, t.scale, t.zero_point, t.dtype] - return register_quantize_fake_quantization(op_name, identity) + return register_fake_quantization_to_integer(op_name, identity) -register_qfq_identity("reshape", relay.op.reshape) -register_qfq_identity("transpose", relay.op.transpose) 
-register_qfq_identity("nn.max_pool2d", relay.op.nn.max_pool2d) +register_unary_identity("reshape", relay.op.reshape) +register_unary_identity("transpose", relay.op.transpose) +register_unary_identity("nn.max_pool2d", relay.op.nn.max_pool2d) -@register_quantize_fake_quantization("nn.avg_pool2d") -def avgpool_qfq(expr, type_map): +@register_fake_quantization_to_integer("nn.avg_pool2d") +def avgpool2d(expr, type_map): """Rewrite a avgpool op""" arg = expr.args[0] t = type_map[arg] @@ -83,8 +84,8 @@ def avgpool_qfq(expr, type_map): return [out, t.scale, t.zero_point, t.dtype] -@register_quantize_fake_quantization("nn.bias_add") -def bias_add_qfq(expr, type_map): +@register_fake_quantization_to_integer("nn.bias_add") +def bias_add(expr, type_map): """Rewrite a bias_add op""" x, b = expr.args x_t = type_map[x] @@ -104,8 +105,8 @@ def bias_add_qfq(expr, type_map): return [out, x_t.scale, x_t.zero_point, x_t.dtype] -@register_quantize_fake_quantization("nn.conv2d") -def conv2d_qfq(expr, type_map): +@register_fake_quantization_to_integer("nn.conv2d") +def conv2d(expr, type_map): """Rewrite a conv2d op""" attrs = {**expr.attrs} attrs.pop("out_dtype") @@ -120,8 +121,8 @@ def conv2d_qfq(expr, type_map): return [out, conv_scale, conv_zp, out.attrs.out_dtype] -@register_quantize_fake_quantization("concatenate") -def concat_qfq(expr, type_map): +@register_fake_quantization_to_integer("concatenate") +def concat(expr, type_map): """Rewrite a concat op""" scales = [] zps = [] @@ -143,8 +144,8 @@ def concat_qfq(expr, type_map): return [out, out_type.scale, out_type.zero_point, out_type.dtype] -@register_quantize_fake_quantization("clip") -def clip_qfq(expr, type_map): +@register_fake_quantization_to_integer("clip") +def clip(expr, type_map): """Rewrite a clip op""" arg = expr.args[0] t = type_map[arg] diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 3647dd79cb62..20e045abab6c 100644 --- a/python/tvm/relay/transform/transform.py +++ 
b/python/tvm/relay/transform/transform.py @@ -1173,7 +1173,7 @@ def AnnotateSpans(): return _ffi_api.AnnotateSpans() -def QuantizeFakeQuantization(): +def FakeQuantizationToInteger(): # pylint: disable=anomalous-backslash-in-string """ Find regions of the graph of the form @@ -1191,11 +1191,11 @@ def QuantizeFakeQuantization(): where q == qnn.quantize and dq = qnn.dequantize and rewrite them into integer versions of op1 and op2 - Rules for rewriting indivdual ops are in quantize_fake_quantization.py + Rules for rewriting indivdual ops are in fake_quantization_to_integer.py Returns ------- ret : tvm.transform.Pass The registered SimplifyExpr pass. """ - return _ffi_api.QuantizeFakeQuantization() + return _ffi_api.FakeQuantizationToInteger() diff --git a/src/relay/transforms/quantize_fake_quantization.cc b/src/relay/transforms/fake_quantization_to_integer.cc similarity index 92% rename from src/relay/transforms/quantize_fake_quantization.cc rename to src/relay/transforms/fake_quantization_to_integer.cc index e4024f7610b5..329800125601 100644 --- a/src/relay/transforms/quantize_fake_quantization.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -27,7 +27,7 @@ #include #include -/* Description of QuantizeFakeQuantization +/* Description of FakeQuantizationToInteger * * The purpose of this pass is to find regions of the graph that follow * the general pattern: @@ -54,7 +54,7 @@ * types of the inputs for later processing * * The third pass is an ExprMutator that recursively rewrites the subgraphs using packed funcs - * registered with the FTVMQuantizeFakeQuantization attribute. These packed funcs rewrite + * registered with the FTVMFakeQuantizationToInteger attribute. These packed funcs rewrite * the ops based on the affine types of their inputs and then return the affine types of the * new rewriten ops to pass that information down the stack during rewrite. 
* @@ -125,7 +125,7 @@ using ExprSet = std::unordered_set; using ExprMap = std::unordered_map; using AffineTypeMap = Map; -using FTVMQuantizeFakeQuantization = +using FTVMFakeQuantizationToInteger = runtime::TypedPackedFunc(const Expr& expr, const AffineTypeMap& map)>; class SubgraphExtractor : public ExprVisitor { @@ -193,7 +193,8 @@ class SubgraphMutator : public ExprMutator { ICHECK(quantize_node); ICHECK(quantize_node->op == quantize_op_); out_type_ = affine_types_[expr]; - static auto fqfq = Op::GetAttrMap("FTVMQuantizeFakeQuantization"); + static auto fqfq = + Op::GetAttrMap("FTVMFakeQuantizationToInteger"); for (auto node : subgraph_) { if (!fqfq.count(Downcast(node.as()->op))) { // Only modify the subgraph if we have translation @@ -208,7 +209,8 @@ class SubgraphMutator : public ExprMutator { Expr VisitExpr_(const CallNode* call_node) { Expr out; - static auto fqfq = Op::GetAttrMap("FTVMQuantizeFakeQuantization"); + static auto fqfq = + Op::GetAttrMap("FTVMFakeQuantizationToInteger"); Op op = Downcast(call_node->op); if (fqfq.count(op)) { Expr expr; @@ -224,7 +226,7 @@ class SubgraphMutator : public ExprMutator { Array vals = fqfq[op](expr, affine_types_); // Save teh outputs of the rewrite ICHECK(vals.size() == 4) - << "got the wrong number of returned arguments from FTWMQuantizeFakeQuantization for " + << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for " << AsText(op, false); out = Downcast(vals[0]); affine_types_.Set(out, AffineType(Downcast(vals[1]), Downcast(vals[2]), @@ -274,22 +276,22 @@ class FakeQuantizationRewriter : public MixedModeMutator { const Op quantize_op_ = Op::Get("qnn.quantize"); }; -Expr QuantizeFakeQuantization(const Expr& expr, const IRModule& mod) { +Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod) { return FakeQuantizationRewriter().Mutate(expr); } namespace transform { -Pass QuantizeFakeQuantization() { +Pass FakeQuantizationToInteger() { runtime::TypedPackedFunc pass_func = 
[=](Function f, IRModule m, PassContext pc) { - return Downcast(QuantizeFakeQuantization(f, m)); + return Downcast(FakeQuantizationToInteger(f, m)); }; - return CreateFunctionPass(pass_func, 0, "QuantizeFakeQuantization", {"InferType"}); + return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.QuantizeFakeQuantization") - .set_body_typed(QuantizeFakeQuantization); +TVM_REGISTER_GLOBAL("relay._transform.FakeQuantizationToInteger") + .set_body_typed(FakeQuantizationToInteger); } // namespace transform diff --git a/tests/python/relay/test_pass_quantize_fake_quantization.py b/tests/python/relay/test_pass_fake_quantization_to_integer.py similarity index 94% rename from tests/python/relay/test_pass_quantize_fake_quantization.py rename to tests/python/relay/test_pass_fake_quantization_to_integer.py index 96daaea7bcc7..3271379cf3ef 100644 --- a/tests/python/relay/test_pass_quantize_fake_quantization.py +++ b/tests/python/relay/test_pass_fake_quantization_to_integer.py @@ -20,7 +20,6 @@ import tvm from tvm import relay -from tvm.relay.dataflow_pattern import * def test_fake_quantize_conv(): @@ -42,7 +41,7 @@ def test_fake_quantize_conv(): x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") - mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) assert not tvm.ir.structural_equal(mod, mod2) mod2 = tvm.relay.transform.FoldConstant()(mod2) @@ -72,7 +71,7 @@ def test_fake_transpose_quantize_conv(): x_np = np.random.randint(-128, 127, size=[1, 224, 224, 3], dtype="int8") w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") - mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) assert not tvm.ir.structural_equal(mod, mod2) mod2 = tvm.relay.transform.FoldConstant()(mod2) 
@@ -105,7 +104,7 @@ def test_fake_transpose_quantize_conv_bias_add(): w_np = np.random.randint(-128, 127, size=[16, 3, 5, 5], dtype="int8") bias_np = np.random.randint(-32768, 32767, size=[16], dtype="int32") - mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) assert not tvm.ir.structural_equal(mod, mod2) mod2 = tvm.relay.transform.FoldConstant()(mod2) @@ -131,7 +130,7 @@ def test_fake_quantize_maxpool(): x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) assert not tvm.ir.structural_equal(mod, mod2) mod2 = tvm.relay.transform.FoldConstant()(mod2) @@ -157,7 +156,7 @@ def test_fake_quantize_avgpool(): x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) assert not tvm.ir.structural_equal(mod, mod2) mod2 = tvm.relay.transform.FoldConstant()(mod2) @@ -183,7 +182,7 @@ def test_fake_quantize_reshape(): x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) assert not tvm.ir.structural_equal(mod, mod2) mod2 = tvm.relay.transform.FoldConstant()(mod2) @@ -210,7 +209,7 @@ def test_fake_quantize_transpose_reshape(): x_np = np.random.randint(-128, 127, size=[1, 3, 224, 224], dtype="int8") - mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) assert not tvm.ir.structural_equal(mod, mod2) mod2 = tvm.relay.transform.FoldConstant()(mod2) @@ -242,7 +241,7 @@ def test_fake_quantize_concat(): for i in range(4): inputs_np.append(np.random.randint(-128, 127, size=[1, 4], dtype="int8")) - mod2 = 
tvm.relay.transform.QuantizeFakeQuantization()(mod) + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) assert not tvm.ir.structural_equal(mod, mod2) mod2 = tvm.relay.transform.FoldConstant()(mod2) @@ -267,7 +266,7 @@ def test_fake_quantize_clip(): x_np = np.random.randint(0, 255, size=[1, 3, 224, 224], dtype="uint8") - mod2 = tvm.relay.transform.QuantizeFakeQuantization()(mod) + mod2 = tvm.relay.transform.FakeQuantizationToInteger()(mod) assert not tvm.ir.structural_equal(mod, mod2) mod2 = tvm.relay.transform.FoldConstant()(mod2) From c150e4f5a5f8e83c03030bb03303428d22a2d22c Mon Sep 17 00:00:00 2001 From: Matthew Date: Mon, 7 Jun 2021 13:10:08 -0600 Subject: [PATCH 6/6] add definition for affine --- python/tvm/relay/op/op.py | 3 ++- src/relay/transforms/fake_quantization_to_integer.cc | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 9a158b2027f3..ccf011819a97 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -440,7 +440,8 @@ def register_fake_quantization_to_integer(op_name, func=None, level=10): """Register quantize function for an op Given an op and Affine Types on it's inputs, this function should return the op - in affine space/integer operators and the new type of the output + in affine space/integer operators and the new type of the output, where affine + denotes the transformation x_real = (x_affine - zero_point) * scale Parameters ---------- diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index 329800125601..1a3c459967bc 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -51,7 +51,8 @@ * * The second pass is an ExprVisitor that recursively searches for subgraphs leading to the * quantize for subtraphs bounded by dequantize operations. 
This pass extracts the affine - * types of the inputs for later processing + * types of the inputs for later processing, where affine denotes the transformation + * x_real = (x_affine - zero_point) * scale * * The third pass is an ExprMutator that recursively rewrites the subgraphs using packed funcs * registered with the FTVMFakeQuantizationToInteger attribute. These packed funcs rewrite