diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 329ba64aae00..b187387a56c2 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -284,6 +284,7 @@ def identity_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(ret_expr, x_kind) +register_annotate_function("reshape", identity_rewrite) register_annotate_function("clip", identity_rewrite) register_annotate_function("nn.relu", identity_rewrite) register_annotate_function("strided_slice", identity_rewrite) diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index 6892e8612a94..563d28366874 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -82,7 +82,7 @@ def add_partition_generic(ref_call, new_args, ctx): # ... lhs = new_args[0].realize() rhs = new_args[1].realize() - return _forward_op(ref_call, [lhs, rhs]) + return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) if not lhs_cond and rhs_cond: # - introduced by residual connection in ResNet # ... 
@@ -130,6 +130,7 @@ def mul_partition_generic(ref_call, new_args, ctx): if lhs_cond: # introduced by bn: multiply(out, scale) + lhs = new_args[0].realize() return QPartitionExpr(_forward_op(ref_call, [lhs, rhs])) if not lhs_cond and not rhs_cond: @@ -155,3 +156,15 @@ def add_partition_function(ref_call, new_args, ctx): def multiply_partition_function(ref_call, new_args, ctx): """Rewrite function for ewise multiply for partition""" return mul_partition_generic(ref_call, new_args, ctx) + + +# Add a cast after the op feeding nn.global_avg_pool2d (the original note says relu; TODO confirm) so that it can run on VTA. +@register_partition_function("nn.global_avg_pool2d") +def global_avg_pool2d_partition_function(ref_call, new_args, ctx): + cond, expr = partition_expr_check(new_args[0]) + if cond: + expr = new_args[0].realize() + else: + expr = QPartitionExpr(new_args[0]).realize() + + return _forward_op(ref_call, [expr]) diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index c96a1b063e98..8db72a3f2b32 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -309,7 +309,8 @@ float ChooseDomScale(const std::vector& nptrs) { /* \brief Unify the dom scale of arguments */ Array UnifyDTypeScale(const Array& ref_args, const Array& args, - DataType* dtype_ptr, Expr* scale_ptr) { + DataType* dtype_ptr, Expr* scale_ptr, + DataType dtype = DataType::Void()) { static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); const QConfig& cfg = QConfig::Current(); @@ -324,13 +325,15 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args // unify the data type ICHECK_EQ(ref_args.size(), args.size()); - DataType dtype; - if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) { - dtype = cfg->dtype_input; - } else { - dtype = cfg->dtype_activation; + if (dtype.is_void()) { + if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) { + dtype = cfg->dtype_input; + } else { + dtype = cfg->dtype_activation; + } } + for (size_t i = 0; i < ret.size(); ++i) { auto ref_arg = 
ref_args[i].as(); if (nptrs[i]->dtype != dtype) { @@ -361,7 +364,16 @@ Expr AddRealize(const Call& ref_call, const Array& new_args, const ObjectR if (new_args[0].as() && new_args[1].as()) { DataType dtype; Expr dom_scale; - Array ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale); + // execute the operation with activation data type. + const QConfig& cfg = QConfig::Current(); + Array ret_args = + UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale, cfg->dtype_activation); + for (size_t i = 0; i < ret_args.size(); ++i) { + // do not fuse float32 arg + if (new_args[i].as()->dtype == DataType::Float(32)) { + ret_args.Set(i, StopFusion(ret_args[i])); + } + } Expr ret = ForwardOp(ref_call, ret_args); return QRealizeIntExpr(ret, dom_scale, dtype); } @@ -430,6 +442,8 @@ Expr IdentityRealize(const Call& ref_call, const Array& new_args, const Ob RELAY_REGISTER_OP("nn.relu").set_attr("FQRealizeRewrite", IdentityRealize); +RELAY_REGISTER_OP("reshape").set_attr("FQRealizeRewrite", IdentityRealize); + RELAY_REGISTER_OP("strided_slice").set_attr("FQRealizeRewrite", IdentityRealize); RELAY_REGISTER_OP("nn.batch_flatten")