From 0f65f7fffc3fc801cecc2f0fd5426a97ac10c6af Mon Sep 17 00:00:00 2001
From: Zhang Hao
Date: Mon, 3 Aug 2020 10:15:23 +0800
Subject: [PATCH 1/2] quant support for alu-only op

---
 python/tvm/relay/quantize/_annotate.py  |  1 +
 python/tvm/relay/quantize/_partition.py | 16 +++++++++++++-
 src/relay/quantize/realize.cc           | 28 ++++++++++++++++++-------
 3 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py
index 329ba64aae00..b187387a56c2 100644
--- a/python/tvm/relay/quantize/_annotate.py
+++ b/python/tvm/relay/quantize/_annotate.py
@@ -284,6 +284,7 @@ def identity_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(ret_expr, x_kind)
 
 
+register_annotate_function("reshape", identity_rewrite)
 register_annotate_function("clip", identity_rewrite)
 register_annotate_function("nn.relu", identity_rewrite)
 register_annotate_function("strided_slice", identity_rewrite)
diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py
index 6892e8612a94..a80ea6c9ad1e 100644
--- a/python/tvm/relay/quantize/_partition.py
+++ b/python/tvm/relay/quantize/_partition.py
@@ -82,7 +82,7 @@ def add_partition_generic(ref_call, new_args, ctx):
         # ...
         lhs = new_args[0].realize()
         rhs = new_args[1].realize()
-        return _forward_op(ref_call, [lhs, rhs])
+        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
     if not lhs_cond and rhs_cond:
         # - introduced by residual connection in ResNet
         # ...
@@ -130,6 +130,7 @@ def mul_partition_generic(ref_call, new_args, ctx):
 
     if lhs_cond:
         # introduced by bn: multiply(out, scale)
+        lhs = new_args[0].realize()
         return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
 
     if not lhs_cond and not rhs_cond:
@@ -155,3 +156,16 @@ def add_partition_function(ref_call, new_args, ctx):
 def multiply_partition_function(ref_call, new_args, ctx):
     """Rewrite function for ewise multiply for partition"""
     return mul_partition_generic(ref_call, new_args, ctx)
+
+
+# add cast after the relu op to make it run on vta
+@register_partition_function("nn.global_avg_pool2d")
+def global_avg_pool2d_partition_function(ref_call, new_args, ctx):
+    cond, expr = partition_expr_check(new_args[0])
+    if cond:
+        expr = new_args[0].realize()
+        return _forward_op(ref_call, [expr])
+    else:
+        expr = QPartitionExpr(new_args[0]).realize()
+        return _forward_op(ref_call, [expr])
+    return None
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index c96a1b063e98..52b31b18a5ff 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -309,7 +309,8 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
 
 /* \brief Unify the dom scale of arguments */
 Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
-                            DataType* dtype_ptr, Expr* scale_ptr) {
+                            DataType* dtype_ptr, Expr* scale_ptr,
+                            DataType dtype = DataType::Void()) {
   static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
   const QConfig& cfg = QConfig::Current();
 
@@ -324,13 +325,15 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
   // unify the data type
   ICHECK_EQ(ref_args.size(), args.size());
 
-  DataType dtype;
-  if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
-    dtype = cfg->dtype_input;
-  } else {
-    dtype = cfg->dtype_activation;
+  if (dtype.is_void()) {
+    if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
+      dtype = cfg->dtype_input;
+    } else {
+      dtype = cfg->dtype_activation;
+    }
   }
 
+
   for (size_t i = 0; i < ret.size(); ++i) {
     auto ref_arg = ref_args[i].as<CallNode>();
     if (nptrs[i]->dtype != dtype) {
@@ -361,7 +364,16 @@ Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectR
   if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
     DataType dtype;
     Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
+    // execute the operation with activation data type.
+    const QConfig& cfg = QConfig::Current();
+    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args,
+                                           &dtype, &dom_scale, cfg->dtype_activation);
+    for (size_t i = 0; i < ret_args.size(); ++i) {
+      // do not fuse float32 arg
+      if (new_args[i].as<QRealizeIntExprNode>()->dtype == DataType::Float(32)) {
+        ret_args.Set(i, StopFusion(ret_args[i]));
+      }
+    }
     Expr ret = ForwardOp(ref_call, ret_args);
     return QRealizeIntExpr(ret, dom_scale, dtype);
   }
@@ -430,6 +442,8 @@ Expr IdentityRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
 
 RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
 
+RELAY_REGISTER_OP("reshape").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
 RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
 
 RELAY_REGISTER_OP("nn.batch_flatten")

From 099aca622c6d1e606d82529d1e74c86ec819bc10 Mon Sep 17 00:00:00 2001
From: Zhang Hao
Date: Thu, 6 Aug 2020 10:26:33 +0800
Subject: [PATCH 2/2] fix lint

---
 python/tvm/relay/quantize/_partition.py | 5 ++---
 src/relay/quantize/realize.cc           | 4 ++--
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py
index a80ea6c9ad1e..563d28366874 100644
--- a/python/tvm/relay/quantize/_partition.py
+++ b/python/tvm/relay/quantize/_partition.py
@@ -164,8 +164,7 @@ def global_avg_pool2d_partition_function(ref_call, new_args, ctx):
     cond, expr = partition_expr_check(new_args[0])
     if cond:
         expr = new_args[0].realize()
-        return _forward_op(ref_call, [expr])
     else:
         expr = QPartitionExpr(new_args[0]).realize()
-        return _forward_op(ref_call, [expr])
-    return None
+
+    return _forward_op(ref_call, [expr])
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 52b31b18a5ff..8db72a3f2b32 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -366,8 +366,8 @@ Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectR
     Expr dom_scale;
     // execute the operation with activation data type.
     const QConfig& cfg = QConfig::Current();
-    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args,
-                                           &dtype, &dom_scale, cfg->dtype_activation);
+    Array<Expr> ret_args =
+        UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale, cfg->dtype_activation);
     for (size_t i = 0; i < ret_args.size(); ++i) {
       // do not fuse float32 arg
       if (new_args[i].as<QRealizeIntExprNode>()->dtype == DataType::Float(32)) {
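
Note (not part of the patch series): below is a rough, untested sketch of how the new registrations might be exercised through the public tvm.relay.quantize API. The graph shape, the "weight" parameter name, and the qconfig values are assumptions made for illustration only; the point is that a tail of add, nn.global_avg_pool2d, and reshape after a quantized conv2d should now stay in the integer domain instead of falling back to float.

import numpy as np
import tvm
from tvm import relay

# conv2d -> add -> global_avg_pool2d -> reshape: the kind of ALU-only tail
# this patch teaches the quantizer to annotate, partition, and realize.
data = relay.var("data", shape=(1, 16, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(16, 16, 3, 3), dtype="float32")
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=16)
out = relay.add(out, relay.const(np.ones((16, 1, 1), dtype="float32")))
out = relay.nn.global_avg_pool2d(out)
out = relay.reshape(out, (1, 16))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# "weight" is a hypothetical parameter name chosen for this sketch.
params = {"weight": np.random.uniform(-1, 1, size=(16, 16, 3, 3)).astype("float32")}

# skip_conv_layers=[] so the single conv layer is quantized and the ops after
# it actually receive quantized inputs; global_scale is an arbitrary choice.
with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0, skip_conv_layers=[]):
    qmod = relay.quantize.quantize(mod, params=params)

# Inspecting the result should show the add/global_avg_pool2d/reshape chain
# operating on casted integer tensors rather than float32.
print(qmod)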