From 9140d33c2bf8f7ee3a4476714d01b189d459b8b6 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 12 Jul 2019 14:00:10 +0000 Subject: [PATCH 01/12] [Relay][Quantization] Support floating-point scale --- src/relay/pass/quantize.cc | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 83d9220ccf79..92df8d2a0dd5 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -166,7 +166,7 @@ inline Expr ForwardOp(const Call& ref_call, const Array& args) { /* calculate `data * s1 / s2`, use shift if possible */ -inline Expr MulAndDiv(Expr data, float s1, float s2) { +inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { // here we assume the dtype of data is dtype activation const QConfig& cfg = QConfig::Current(); if (s1 == s2) return data; @@ -175,14 +175,14 @@ inline Expr MulAndDiv(Expr data, float s1, float s2) { float shift_factor = std::log2(factor); CHECK_GT(shift_factor, 0); if (static_cast(shift_factor) == shift_factor) { - return LeftShift(data, MakeConstantScalar(cfg->dtype_activation, + return LeftShift(data, MakeConstantScalar(dtype, static_cast(shift_factor))); } else if (static_cast(factor) == factor) { - return Multiply(data, MakeConstantScalar(cfg->dtype_activation, factor)); + return Multiply(data, MakeConstantScalar(dtype, factor)); } else { - LOG(FATAL) << "fall back to float computation"; data = Cast(data, Float(32)); - return Multiply(data, MakeConstantScalar(Float(32), factor)); + data = Multiply(data, MakeConstantScalar(Float(32), factor)); + return Cast(Round(data), dtype); } } @@ -338,15 +338,11 @@ Expr MulRealize(const Call& ref_call, Expr rdata = rhs->data; DataType dtype = cfg->dtype_activation; - if (lhs->dtype == Float(32)) { + if (lhs->dtype != dtype) { ldata = Cast(ldata, dtype); - } else { - CHECK_EQ(lhs->dtype, dtype); } - if (rhs->dtype == Float(32)) { + if (rhs->dtype != dtype) { rdata = Cast(rdata, dtype); - } else { - CHECK_EQ(rhs->dtype, dtype); } Expr ret = ForwardOp(ref_call, {ldata, rdata}); @@ -418,7 +414,7 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args Expr dom_scale = MakeConstantScalar(Float(32), s); for (size_t i = 0; i < ret.size(); ++i) { float cur_s = GetScalarFromConstant(nptrs[i]->dom_scale); - ret.Set(i, MulAndDiv(ret[i], cur_s, s)); + ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype)); } *dtype_ptr = dtype; From c9f46c4443d26d06fd175146d195171b01c1dfc6 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 12 Jul 2019 12:09:24 +0000 Subject: [PATCH 02/12] [Relay][Quantization] KL-divergence calibration on dataset --- python/tvm/relay/quantize/__init__.py | 1 + python/tvm/relay/quantize/_annotate.py | 2 +- python/tvm/relay/quantize/kl_divergence.py | 123 +++++++++++++++++++++ python/tvm/relay/quantize/quantize.py | 37 ++++++- src/relay/pass/quantize.cc | 42 +++++++ src/relay/pass/quantize.h | 5 + 6 files changed, 205 insertions(+), 5 deletions(-) create mode 100644 python/tvm/relay/quantize/kl_divergence.py diff --git a/python/tvm/relay/quantize/__init__.py b/python/tvm/relay/quantize/__init__.py index 45bb62e66853..a9e7b40b039e 100644 --- a/python/tvm/relay/quantize/__init__.py +++ b/python/tvm/relay/quantize/__init__.py @@ -20,3 +20,4 @@ from .quantize import * from ._annotate import register_annotate_function +from .kl_divergence import kl_divergence_scale diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 7b7f9c42f2f1..a365ec9fad09 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -230,7 +230,7 @@ def multiply_rewrite(ref_call, new_args, ctx): if lhs_kind == QAnnotateKind.ACTIVATION: lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) # quantize rhs to WEIGHT field - rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.BIAS) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) diff --git a/python/tvm/relay/quantize/kl_divergence.py b/python/tvm/relay/quantize/kl_divergence.py new file mode 100644 index 000000000000..b4081a8bb7cc --- /dev/null +++ b/python/tvm/relay/quantize/kl_divergence.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +try: + from scipy import stats +except ImportError: + stats = None + +import numpy as np + + +def _smooth_distribution(p, eps=0.0001): + """Given a discrete distribution (may have not been normalized to 1), + smooth it by replacing zeros with eps multiplied by a scaling factor and taking the + corresponding amount off the non-zero values. + Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf + """ + is_zeros = (p == 0).astype(np.float32) + is_nonzeros = (p != 0).astype(np.float32) + n_zeros = is_zeros.sum() + n_nonzeros = p.size - n_zeros + if not n_nonzeros: + raise ValueError('The discrete probability distribution is malformed. All entries are 0.') + eps1 = eps * float(n_zeros) / float(n_nonzeros) + assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1) + hist = p.astype(np.float32) + hist += eps * is_zeros + (-eps1) * is_nonzeros + assert (hist <= 0).sum() == 0 + return hist + + +# pylint: disable=line-too-long +def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255): + """Given a dataset, find the optimal threshold for quantizing it. + The reference distribution is `q`, and the candidate distribution is `p`. + `q` is a truncated version of the original distribution. + + Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + """ + assert isinstance(arr, np.ndarray) + + min_val = np.min(arr) + max_val = np.max(arr) + th = max(abs(min_val), abs(max_val)) + + if min_val >= 0 and quantized_dtype in ['uint8']: + # We need to move negative bins to positive bins to fit uint8 range. + num_quantized_bins = num_quantized_bins * 2 + 1 + + hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th)) + zero_bin_idx = num_bins // 2 + num_half_quantized_bins = num_quantized_bins // 2 + + thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2) + divergence = np.zeros_like(thresholds) + quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32) + # i means the number of bins on half axis excluding the zero bin. + for i in range(num_quantized_bins // 2, + num_bins // 2 + 1): + p_bin_idx_start = zero_bin_idx - i + p_bin_idx_stop = zero_bin_idx + i + 1 + thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop] + sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop] + + # generate reference distribution p + p = sliced_nd_hist.copy() + assert p.size % 2 == 1 + assert p.size >= num_quantized_bins + # put left outlier count in p[0] + left_outlier_count = np.sum(hist[0:p_bin_idx_start]) + p[0] += left_outlier_count + # put right outlier count in p[-1] + right_outlier_count = np.sum(hist[p_bin_idx_stop:]) + p[-1] += right_outlier_count + # is_nonzeros[k] indicates whether hist[k] is nonzero + is_nonzeros = (p != 0).astype(np.int32) + + # calculate how many bins should be merged to generate quantized distribution q + num_merged_bins = sliced_nd_hist.size // num_quantized_bins + # merge hist into num_quantized_bins bins + for j in range(num_quantized_bins): + start = j * num_merged_bins + stop = start + num_merged_bins + quantized_bins[j] = sliced_nd_hist[start:stop].sum() + quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum() + # expand quantized_bins into p.size bins + q = np.zeros(sliced_nd_hist.size, dtype=np.float32) + for j in range(num_quantized_bins): + start = j * num_merged_bins + if j == num_quantized_bins - 1: + stop = len(is_nonzeros) + else: + stop = start + num_merged_bins + norm = is_nonzeros[start:stop].sum() + if norm != 0: + q[start:stop] = float(quantized_bins[j]) / float(norm) + q[p == 0] = 0 + p = _smooth_distribution(p) + # There is a chance that q is an invalid probability distribution. + try: + q = _smooth_distribution(q) + except ValueError: + divergence[i - num_half_quantized_bins] = float("inf") + divergence[i - num_half_quantized_bins] = stats.entropy(p, q) + + min_divergence_idx = np.argmin(divergence) + opt_th = thresholds[min_divergence_idx] + return opt_th +# pylint: enable=line-too-long diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index beebceaf8590..81ab2f9dc356 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -35,6 +35,7 @@ class QAnnotateKind(object): INPUT = 1 WEIGHT = 2 ACTIVATION = 3 + BIAS = 4 def kind2str(kind): @@ -43,6 +44,7 @@ def kind2str(kind): QAnnotateKind.INPUT: "input", QAnnotateKind.WEIGHT: "weight", QAnnotateKind.ACTIVATION: "activation", + QAnnotateKind.BIAS: "bias", } assert kind in str_map return str_map[kind] @@ -67,9 +69,11 @@ class QConfig(NodeBase): "nbit_input": 8, "nbit_weight": 8, "nbit_activation": 32, + "nbit_bias": 32, "dtype_input": "int8", "dtype_weight": "int8", "dtype_activation": "int32", + "dtype_bias": "int32", "global_scale": 8.0, "skip_conv_layers": [0], "round_for_shift": True, @@ -195,7 +199,11 @@ def annotate_context(): return AnnotateContext.Current -def calibrate(graph, mod=None, ctx=None): +def collect_stats(graph): + return _quantize.CollectStats(graph) + + +def calibrate(graph, mod=None, ctx=None, scales=None): """The calibrate procedure will try to calculate the content of dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` operator. @@ -221,12 +229,22 @@ def power2_scale(arr): val = np.amax(np.abs(arr.asnumpy())) return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 + def max_scale(arr): + val = np.amax(np.abs(arr.asnumpy())) + return val + + scale_idx = 0 + + #fcalib_weight = power2_scale + fcalib_weight = max_scale + cfg = current_qconfig() const_params = {} quantize_op = _op.get("relay.op.annotation.simulated_quantize") def visit_func(expr): """Internal visit function""" + nonlocal scale_idx if isinstance(expr, _expr.Call) and expr.op == quantize_op: _, ndom_scale, nclip_min, nclip_max = expr.args attrs = expr.attrs @@ -234,11 +252,18 @@ def visit_func(expr): nbit = cfg.get_nbit_by_kind(kind) valid_bit = nbit - attrs.sign - - if kind == QAnnotateKind.WEIGHT: + if kind in [QAnnotateKind.WEIGHT, QAnnotateKind.BIAS]: + if all([isinstance(arg, _expr.Constant) for arg in [ndom_scale, nclip_min, nclip_max]]): + return var = expr.args[0] assert isinstance(var, _expr.Constant) - scale = power2_scale(var.data) + scale = fcalib_weight(var.data) + print('weight scale: {}'.format(scale)) + elif scales is not None: + scale = scales[scale_idx] + scale_idx += 1 + print('{} / {} ...'.format(scale_idx, len(scales))) + print('act scale: {}'.format(scale)) else: scale = cfg.global_scale @@ -246,6 +271,10 @@ def _make_const(val): return _expr.const(val, 'float32') valid_range = 2**valid_bit + if kind == QAnnotateKind.BIAS: + # bias hack + valid_range = 2**15 + const_params[ndom_scale] = _make_const(scale / valid_range) const_params[nclip_min] = _make_const(- (valid_range - 1)) const_params[nclip_max] = _make_const((valid_range - 1)) diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 92df8d2a0dd5..600f3e680391 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -629,6 +629,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "nbit_input=" << op->nbit_input << ", "; p->stream << "nbit_weight=" << op->nbit_weight << ", "; p->stream << "nbit_activation=" << op->nbit_activation << ", "; + p->stream << "nbit_bias=" << op->nbit_bias << ", "; p->stream << "global_scale=" << op->global_scale << ", "; p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; @@ -734,7 +735,48 @@ TVM_REGISTER_API("relay._quantize.temp_expr_realize") return n->Realize(); }); +// ============= +// calibration + +class StatsCollector : private ExprMutator { + public: + Expr Collect(const Expr& expr) { + auto new_e = this->Mutate(expr); + const FunctionNode* func = new_e.as(); + CHECK(func); + Expr new_body = TupleNode::make(std::move(profile_data_)); + return FunctionNode::make(FreeVars(new_body), new_body, NullValue(), func->type_params, + func->attrs); + } + + private: + Array profile_data_; + + Expr VisitExpr_(const CallNode* call) { + static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); + Expr new_e = ExprMutator::VisitExpr_(call); + const CallNode* new_call = new_e.as(); + CHECK(new_call); + if (new_call->op.same_as(simulated_quantize)) { + auto attrs = new_call->attrs.as(); + if (attrs->kind != QAnnotateKind::kQWeight && attrs->kind != QAnnotateKind::kQBias) { + CHECK(!new_call->args[0].as()); + profile_data_.push_back(new_call->args[0]); + } + return new_call->args[0]; + } else { + return new_e; + } + } +}; + +Expr CollectStats(const Expr& expr) { + return StatsCollector().Collect(expr); +} +TVM_REGISTER_API("relay._quantize.CollectStats") +.set_body_typed(CollectStats); + } // namespace quantize } // namespace relay } // namespace tvm diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h index 262d420acf97..2d0d5c968d5c 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize.h @@ -40,6 +40,7 @@ enum QAnnotateKind : int { kQInput = 1, kQWeight = 2, kQActivation = 3, + kQBias = 4, }; /*! @@ -148,9 +149,11 @@ class QConfigNode : public Node { int nbit_input = 8; int nbit_weight = 8; int nbit_activation = 32; + int nbit_bias = 32; DataType dtype_input = Int(8); DataType dtype_weight = Int(8); DataType dtype_activation = Int(32); + DataType dtype_bias = Int(32); double global_scale = 8.0; Array skip_conv_layers = Array(NodePtr(nullptr)); bool round_for_shift = true; @@ -161,9 +164,11 @@ class QConfigNode : public Node { v->Visit("nbit_input", &nbit_input); v->Visit("nbit_weight", &nbit_weight); v->Visit("nbit_activation", &nbit_activation); + v->Visit("nbit_bias", &nbit_bias); v->Visit("dtype_input", &dtype_input); v->Visit("dtype_weight", &dtype_weight); v->Visit("dtype_activation", &dtype_activation); + v->Visit("dtype_bias", &dtype_bias); v->Visit("global_scale", &global_scale); v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("round_for_shift", &round_for_shift); From 0b9c5f74874e2df14bb6c695b6c28e2e3d6ca79a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 16 Jul 2019 08:27:18 +0000 Subject: [PATCH 03/12] Fix unhandled LeftShift case in QuantizeRealize --- src/relay/pass/quantize.cc | 61 +++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 600f3e680391..4934aa5b0243 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -216,15 +216,21 @@ Expr QuantizeRealize(const Call& ref_call, } float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); - CHECK_GT(shift_nbit, 0); + CHECK_NE(shift_nbit, 0); if (static_cast(shift_nbit) == shift_nbit) { - // use right shift - if (cfg->round_for_shift) { - float round_bias = std::pow(2.0, shift_nbit - 1); - data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast(round_bias))); + if (shift_nbit > 0) { + // use right shift + if (cfg->round_for_shift) { + float round_bias = std::pow(2.0, shift_nbit - 1); + data = Add(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(round_bias))); + } + data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(shift_nbit))); + } else { + data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(shift_nbit))); } - data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); data = Clip(data, clip_min_imm, clip_max_imm); return QRealizeIntExprNode::make(data, dom_scale, n->dtype); } else { @@ -741,12 +747,12 @@ TVM_REGISTER_API("relay._quantize.temp_expr_realize") class StatsCollector : private ExprMutator { public: Expr Collect(const Expr& expr) { - auto new_e = this->Mutate(expr); - const FunctionNode* func = new_e.as(); - CHECK(func); - Expr new_body = TupleNode::make(std::move(profile_data_)); - return FunctionNode::make(FreeVars(new_body), new_body, NullValue(), func->type_params, - func->attrs); + auto new_e = this->Mutate(expr); + const FunctionNode* func = new_e.as(); + CHECK(func); + Expr new_body = TupleNode::make(std::move(profile_data_)); + return FunctionNode::make(FreeVars(new_body), new_body, NullValue(), func->type_params, + func->attrs); } private: @@ -754,22 +760,23 @@ class StatsCollector : private ExprMutator { Expr VisitExpr_(const CallNode* call) { static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); - Expr new_e = ExprMutator::VisitExpr_(call); - const CallNode* new_call = new_e.as(); - CHECK(new_call); - if (new_call->op.same_as(simulated_quantize)) { - auto attrs = new_call->attrs.as(); - if (attrs->kind != QAnnotateKind::kQWeight && attrs->kind != QAnnotateKind::kQBias) { - CHECK(!new_call->args[0].as()); - profile_data_.push_back(new_call->args[0]); - } - return new_call->args[0]; - } else { - return new_e; - } + Expr new_e = ExprMutator::VisitExpr_(call); + const CallNode* new_call = new_e.as(); + CHECK(new_call); + if (new_call->op.same_as(simulated_quantize)) { + auto attrs = new_call->attrs.as(); + if (attrs->kind != QAnnotateKind::kQWeight && attrs->kind != QAnnotateKind::kQBias) { + CHECK(!new_call->args[0].as()); + const Expr& quantize_input = new_call->args[0]; // expression being quantized + profile_data_.push_back(quantize_input); + } + return new_call->args[0]; + } else { + return new_e; + } } }; - + Expr CollectStats(const Expr& expr) { return StatsCollector().Collect(expr); } From 20ec8554d8a5768ec678576787f5f80cc9598729 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 16 Jul 2019 08:27:49 +0000 Subject: [PATCH 04/12] Fix lint --- python/tvm/relay/quantize/kl_divergence.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/quantize/kl_divergence.py b/python/tvm/relay/quantize/kl_divergence.py index b4081a8bb7cc..3444ccda5b30 100644 --- a/python/tvm/relay/quantize/kl_divergence.py +++ b/python/tvm/relay/quantize/kl_divergence.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Find optimal scale for quantization by minimizing KL-divergence""" try: from scipy import stats @@ -43,7 +44,7 @@ def _smooth_distribution(p, eps=0.0001): return hist -# pylint: disable=line-too-long +# pylint: disable=line-too-long,invalid-name def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255): """Given a dataset, find the optimal threshold for quantizing it. The reference distribution is `q`, and the candidate distribution is `p`. From 0e555185b483c12f35b731f7b54dfcc38ada9c74 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 16 Jul 2019 14:20:42 +0000 Subject: [PATCH 05/12] drop QBias --- python/tvm/relay/quantize/_annotate.py | 7 ++++--- python/tvm/relay/quantize/quantize.py | 29 ++++++++++---------------- src/relay/pass/quantize.cc | 3 +-- src/relay/pass/quantize.h | 5 ----- 4 files changed, 16 insertions(+), 28 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index a365ec9fad09..9d81f58d3279 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -230,7 +230,7 @@ def multiply_rewrite(ref_call, new_args, ctx): if lhs_kind == QAnnotateKind.ACTIVATION: lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) # quantize rhs to WEIGHT field - rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.BIAS) + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) @@ -251,7 +251,7 @@ def add_rewrite(ref_call, new_args, ctx): if lhs_kind is None and rhs_kind is not None: # quantize lhs to INPUT field if it is normal expression - assert rhs_kind == QAnnotateKind.INPUT + assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION] lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.INPUT) @@ -275,7 +275,8 @@ def add_rewrite(ref_call, new_args, ctx): rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) - if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT: + if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or \ + (lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION): expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) raise ValueError() diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 81ab2f9dc356..562e97a79768 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -35,7 +35,6 @@ class QAnnotateKind(object): INPUT = 1 WEIGHT = 2 ACTIVATION = 3 - BIAS = 4 def kind2str(kind): @@ -44,7 +43,6 @@ def kind2str(kind): QAnnotateKind.INPUT: "input", QAnnotateKind.WEIGHT: "weight", QAnnotateKind.ACTIVATION: "activation", - QAnnotateKind.BIAS: "bias", } assert kind in str_map return str_map[kind] @@ -69,11 +67,9 @@ class QConfig(NodeBase): "nbit_input": 8, "nbit_weight": 8, "nbit_activation": 32, - "nbit_bias": 32, "dtype_input": "int8", "dtype_weight": "int8", "dtype_activation": "int32", - "dtype_bias": "int32", "global_scale": 8.0, "skip_conv_layers": [0], "round_for_shift": True, @@ -203,7 +199,7 @@ def collect_stats(graph): return _quantize.CollectStats(graph) -def calibrate(graph, mod=None, ctx=None, scales=None): +def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None): """The calibrate procedure will try to calculate the content of dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` operator. @@ -230,14 +226,12 @@ def power2_scale(arr): return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 def max_scale(arr): + """calculate weight scale with maximum absolute value""" val = np.amax(np.abs(arr.asnumpy())) return val scale_idx = 0 - #fcalib_weight = power2_scale - fcalib_weight = max_scale - cfg = current_qconfig() const_params = {} quantize_op = _op.get("relay.op.annotation.simulated_quantize") @@ -252,18 +246,21 @@ def visit_func(expr): nbit = cfg.get_nbit_by_kind(kind) valid_bit = nbit - attrs.sign - if kind in [QAnnotateKind.WEIGHT, QAnnotateKind.BIAS]: - if all([isinstance(arg, _expr.Constant) for arg in [ndom_scale, nclip_min, nclip_max]]): + if kind in [QAnnotateKind.WEIGHT]: + if all([isinstance(arg, _expr.Constant) + for arg in [ndom_scale, nclip_min, nclip_max]]): return var = expr.args[0] assert isinstance(var, _expr.Constant) - scale = fcalib_weight(var.data) - print('weight scale: {}'.format(scale)) + if weight_scales == 'max': + scale = max_scale(var.data) + elif weight_scales == 'power2': + scale = power2_scale(var.data) + else: + raise ValueError('{} not supported'.format(weight_scales)) elif scales is not None: scale = scales[scale_idx] scale_idx += 1 - print('{} / {} ...'.format(scale_idx, len(scales))) - print('act scale: {}'.format(scale)) else: scale = cfg.global_scale @@ -271,10 +268,6 @@ def _make_const(val): return _expr.const(val, 'float32') valid_range = 2**valid_bit - if kind == QAnnotateKind.BIAS: - # bias hack - valid_range = 2**15 - const_params[ndom_scale] = _make_const(scale / valid_range) const_params[nclip_min] = _make_const(- (valid_range - 1)) const_params[nclip_max] = _make_const((valid_range - 1)) diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 4934aa5b0243..efe79c1b2b2d 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -635,7 +635,6 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "nbit_input=" << op->nbit_input << ", "; p->stream << "nbit_weight=" << op->nbit_weight << ", "; p->stream << "nbit_activation=" << op->nbit_activation << ", "; - p->stream << "nbit_bias=" << op->nbit_bias << ", "; p->stream << "global_scale=" << op->global_scale << ", "; p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", "; @@ -765,7 +764,7 @@ class StatsCollector : private ExprMutator { CHECK(new_call); if (new_call->op.same_as(simulated_quantize)) { auto attrs = new_call->attrs.as(); - if (attrs->kind != QAnnotateKind::kQWeight && attrs->kind != QAnnotateKind::kQBias) { + if (attrs->kind != QAnnotateKind::kQWeight) { CHECK(!new_call->args[0].as()); const Expr& quantize_input = new_call->args[0]; // expression being quantized profile_data_.push_back(quantize_input); diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h index 2d0d5c968d5c..262d420acf97 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize.h @@ -40,7 +40,6 @@ enum QAnnotateKind : int { kQInput = 1, kQWeight = 2, kQActivation = 3, - kQBias = 4, }; /*! @@ -149,11 +148,9 @@ class QConfigNode : public Node { int nbit_input = 8; int nbit_weight = 8; int nbit_activation = 32; - int nbit_bias = 32; DataType dtype_input = Int(8); DataType dtype_weight = Int(8); DataType dtype_activation = Int(32); - DataType dtype_bias = Int(32); double global_scale = 8.0; Array skip_conv_layers = Array(NodePtr(nullptr)); bool round_for_shift = true; @@ -164,11 +161,9 @@ class QConfigNode : public Node { v->Visit("nbit_input", &nbit_input); v->Visit("nbit_weight", &nbit_weight); v->Visit("nbit_activation", &nbit_activation); - v->Visit("nbit_bias", &nbit_bias); v->Visit("dtype_input", &dtype_input); v->Visit("dtype_weight", &dtype_weight); v->Visit("dtype_activation", &dtype_activation); - v->Visit("dtype_bias", &dtype_bias); v->Visit("global_scale", &global_scale); v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("round_for_shift", &round_for_shift); From 577387d808750063744d54770ee42d40c1949d41 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 16 Jul 2019 15:30:13 +0000 Subject: [PATCH 06/12] fix lint --- src/relay/pass/quantize.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index efe79c1b2b2d..81044a9addad 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -168,7 +168,6 @@ inline Expr ForwardOp(const Call& ref_call, const Array& args) { /* calculate `data * s1 / s2`, use shift if possible */ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) { // here we assume the dtype of data is dtype activation - const QConfig& cfg = QConfig::Current(); if (s1 == s2) return data; float factor = s1 / s2; @@ -766,7 +765,7 @@ class StatsCollector : private ExprMutator { auto attrs = new_call->attrs.as(); if (attrs->kind != QAnnotateKind::kQWeight) { CHECK(!new_call->args[0].as()); - const Expr& quantize_input = new_call->args[0]; // expression being quantized + const Expr& quantize_input = new_call->args[0]; // expression being quantized profile_data_.push_back(quantize_input); } return new_call->args[0]; @@ -782,7 +781,7 @@ Expr CollectStats(const Expr& expr) { TVM_REGISTER_API("relay._quantize.CollectStats") .set_body_typed(CollectStats); - + } // namespace quantize } // namespace relay } // namespace tvm From 7ba8f30c71caa274353781d2bb6da71f0f45ae7f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 22 Jul 2019 06:40:13 +0000 Subject: [PATCH 07/12] address comments --- python/tvm/relay/quantize/quantize.py | 24 ++++++ src/relay/pass/quantize/calibration.cc | 89 +++++++++++++++++++++++ src/relay/pass/{ => quantize}/quantize.cc | 62 +--------------- src/relay/pass/{ => quantize}/quantize.h | 24 +++++- 4 files changed, 135 insertions(+), 64 deletions(-) create mode 100644 src/relay/pass/quantize/calibration.cc rename src/relay/pass/{ => quantize}/quantize.cc (92%) rename src/relay/pass/{ => quantize}/quantize.h (90%) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 562e97a79768..ac806f2ab5f0 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -196,6 +196,20 @@ def annotate_context(): def collect_stats(graph): + """Given an annotated graph, create a profile graph to collect profile data from the + calibration dataset. This pass finds simulated_quantize op and collects its input into a tuple. + The tuple is the output of the profile graph. + + Parameters + ---------- + graph: Function + The simulation graph after annotation. + + Returns + ------- + ret: Function + The profile graph which outputs a tuple of profile data. + """ return _quantize.CollectStats(graph) @@ -215,6 +229,16 @@ def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None): ctx: tvm.relay.PassContext The pass context used for calibration. + weight_scales: 'power2' or 'max'. + The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT). + power2: Find the maximum of the absolute value of the tensor, and then round up to power + of two. + max: Find the maximum of the absolute value of the tensor. + + scales: List[float] + Pre-calculated scales for input and activations. Length and the order of elements of the + scales list should match the output tuple of the profile graph created by collect_stats. + Returns ------- ret: Function diff --git a/src/relay/pass/quantize/calibration.cc b/src/relay/pass/quantize/calibration.cc new file mode 100644 index 000000000000..e3d0f5a3078f --- /dev/null +++ b/src/relay/pass/quantize/calibration.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file calibration.cc + * + * \brief Create profile graph and calibrate on dataset + */ +#include +#include +#include "./quantize.h" + + +namespace tvm { +namespace relay { +namespace quantize { + +class StatsCollector : private ExprMutator { + public: + Expr Collect(const Expr& expr) { + auto new_e = this->Mutate(expr); + const FunctionNode* func = new_e.as(); + CHECK(func) << "Input shoule be Function"; + Expr new_body = TupleNode::make(std::move(profile_data_)); + return FunctionNode::make(FreeVars(new_body), new_body, NullValue(), func->type_params, + func->attrs); + } + + private: + Array profile_data_; + + Expr VisitExpr_(const CallNode* call) { + static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); + Expr new_e = ExprMutator::VisitExpr_(call); + const CallNode* new_call = new_e.as(); + CHECK(new_call); + if (new_call->op.same_as(simulated_quantize)) { + auto attrs = new_call->attrs.as(); + const Expr& quantize_input = new_call->args[0]; // expression being quantized + if (attrs->kind != QAnnotateKind::kQWeight) { + CHECK(!quantize_input.as()); + profile_data_.push_back(quantize_input); + } + return quantize_input; + } else { + return new_e; + } + } +}; + +/* + * \brief Given an annotated graph, create a profile graph to collect profile data from the + * calibration dataset. + * + * This pass finds simulated_quantize op and collects its input into a tuple. The tuple is the + * output of the profile graph. Both input and output of this pass + * are relay::Function. + * + * \param expr The simulation graph after annotation. + * \return The profile graph. + */ +Expr CollectStats(const Expr& expr) { + return StatsCollector().Collect(expr); +} + +TVM_REGISTER_API("relay._quantize.CollectStats") +.set_body_typed(CollectStats); + +} // namespace quantize +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize/quantize.cc similarity index 92% rename from src/relay/pass/quantize.cc rename to src/relay/pass/quantize/quantize.cc index 81044a9addad..6cffc2053e5c 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize/quantize.cc @@ -36,8 +36,8 @@ #include #include #include -#include "pattern_util.h" -#include "quantize.h" +#include "../pattern_util.h" +#include "./quantize.h" namespace tvm { @@ -46,22 +46,6 @@ namespace quantize { using namespace relay::transform; -/*! \brief Attribute for simulated quantize operator */ -struct SimulatedQuantizeAttrs : public tvm::AttrsNode { - int kind; - bool sign; - std::string rounding; - - TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { - TVM_ATTR_FIELD(kind) - .describe("kind of field, hint for nbit/dtype configuration."); - TVM_ATTR_FIELD(sign).set_default(true) - .describe("whether to use signed data type."); - TVM_ATTR_FIELD(rounding).set_default("round") - .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); - } -}; - TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); bool SimulatedQuantizeRel(const Array& types, @@ -739,48 +723,6 @@ TVM_REGISTER_API("relay._quantize.temp_expr_realize") return n->Realize(); }); -// ============= -// calibration - -class StatsCollector : private ExprMutator { - public: - Expr Collect(const Expr& expr) { - auto new_e = this->Mutate(expr); - const FunctionNode* func = new_e.as(); - CHECK(func); - Expr new_body = TupleNode::make(std::move(profile_data_)); - return FunctionNode::make(FreeVars(new_body), new_body, NullValue(), func->type_params, - func->attrs); - } - - private: - Array profile_data_; - - Expr VisitExpr_(const CallNode* call) { - static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize"); - Expr new_e = ExprMutator::VisitExpr_(call); - const CallNode* new_call = new_e.as(); - CHECK(new_call); - if (new_call->op.same_as(simulated_quantize)) { - auto attrs = new_call->attrs.as(); - if (attrs->kind != QAnnotateKind::kQWeight) { - CHECK(!new_call->args[0].as()); - const Expr& quantize_input = new_call->args[0]; // expression being quantized - profile_data_.push_back(quantize_input); - } - return new_call->args[0]; - } else { - return new_e; - } - } -}; - -Expr CollectStats(const Expr& expr) { - return StatsCollector().Collect(expr); -} - -TVM_REGISTER_API("relay._quantize.CollectStats") -.set_body_typed(CollectStats); } // namespace quantize } // namespace relay diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize/quantize.h similarity index 90% rename from src/relay/pass/quantize.h rename to src/relay/pass/quantize/quantize.h index 262d420acf97..e3c88b5c3565 100644 --- a/src/relay/pass/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -23,13 +23,13 @@ * \file tvm/relay/pass/quantize.h * \brief Header of definitions for quantization */ -#ifndef TVM_RELAY_PASS_QUANTIZE_H_ -#define TVM_RELAY_PASS_QUANTIZE_H_ +#ifndef TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_ +#define TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_ #include #include #include -#include "pattern_util.h" +#include "../pattern_util.h" namespace tvm { namespace relay { @@ -42,6 +42,22 @@ enum QAnnotateKind : int { kQActivation = 3, }; +/*! \brief Attribute for simulated quantize operator */ +struct SimulatedQuantizeAttrs : public tvm::AttrsNode { + int kind; + bool sign; + std::string rounding; + + TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { + TVM_ATTR_FIELD(kind) + .describe("kind of field, hint for nbit/dtype configuration."); + TVM_ATTR_FIELD(sign).set_default(true) + .describe("whether to use signed data type."); + TVM_ATTR_FIELD(rounding).set_default("round") + .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); + } +}; + /*! * \brief TempExpr used during annotate forward rewrite. */ @@ -242,4 +258,4 @@ TVM_DLL QConfig qconfig(); } // namespace quantize } // namespace relay } // namespace tvm -#endif // TVM_RELAY_PASS_QUANTIZE_H_ +#endif // TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_ From a7557a405b6248a8d09596293f34c6bb8a8b63ce Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 23 Jul 2019 04:57:51 +0000 Subject: [PATCH 08/12] address comments --- python/tvm/relay/quantize/kl_divergence.py | 6 +++--- src/relay/pass/quantize/{calibration.cc => calibrate.cc} | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) rename src/relay/pass/quantize/{calibration.cc => calibrate.cc} (99%) diff --git a/python/tvm/relay/quantize/kl_divergence.py b/python/tvm/relay/quantize/kl_divergence.py index 3444ccda5b30..21bce4c2eb3c 100644 --- a/python/tvm/relay/quantize/kl_divergence.py +++ b/python/tvm/relay/quantize/kl_divergence.py @@ -44,13 +44,14 @@ def _smooth_distribution(p, eps=0.0001): return hist -# pylint: disable=line-too-long,invalid-name +# pylint: disable=invalid-name def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255): """Given a dataset, find the optimal threshold for quantizing it. The reference distribution is `q`, and the candidate distribution is `p`. `q` is a truncated version of the original distribution. - Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + Ref: + http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf """ assert isinstance(arr, np.ndarray) @@ -121,4 +122,3 @@ def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantize min_divergence_idx = np.argmin(divergence) opt_th = thresholds[min_divergence_idx] return opt_th -# pylint: enable=line-too-long diff --git a/src/relay/pass/quantize/calibration.cc b/src/relay/pass/quantize/calibrate.cc similarity index 99% rename from src/relay/pass/quantize/calibration.cc rename to src/relay/pass/quantize/calibrate.cc index e3d0f5a3078f..fadc3f2c67ea 100644 --- a/src/relay/pass/quantize/calibration.cc +++ b/src/relay/pass/quantize/calibrate.cc @@ -20,7 +20,7 @@ /*! * Copyright (c) 2019 by Contributors * - * \file calibration.cc + * \file calibrate.cc * * \brief Create profile graph and calibrate on dataset */ From 8e9be266b92dfae0ed19575711a3df667faf4579 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 26 Jul 2019 19:00:38 +0800 Subject: [PATCH 09/12] Update comments --- python/tvm/relay/quantize/kl_divergence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/quantize/kl_divergence.py b/python/tvm/relay/quantize/kl_divergence.py index 21bce4c2eb3c..bce45dca6f1c 100644 --- a/python/tvm/relay/quantize/kl_divergence.py +++ b/python/tvm/relay/quantize/kl_divergence.py @@ -46,7 +46,7 @@ def _smooth_distribution(p, eps=0.0001): # pylint: disable=invalid-name def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255): - """Given a dataset, find the optimal threshold for quantizing it. + """Given a tensor, find the optimal threshold for quantizing it. The reference distribution is `q`, and the candidate distribution is `p`. `q` is a truncated version of the original distribution. From 3e09ee9d40ef807e4f879c712871647dbdfccba3 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sat, 27 Jul 2019 07:16:16 +0000 Subject: [PATCH 10/12] address comments --- python/tvm/relay/quantize/_annotate.py | 5 ++++- python/tvm/relay/quantize/quantize.py | 7 +++++-- src/relay/pass/quantize/calibrate.cc | 18 ++++++++++++++---- src/relay/pass/quantize/quantize.h | 1 + 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 9d81f58d3279..e03eaab507ad 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -39,6 +39,9 @@ def simulated_quantize_compute(attrs, inputs, out_type, target): data, scale, clip_min, clip_max = inputs + if attrs.kind == QAnnotateKind.IDENTITY: + return [topi.identity(data)] + # simulate rounding error scaled_data = topi.divide(data, scale) clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min) @@ -52,7 +55,7 @@ def simulated_quantize_compute(attrs, inputs, out_type, target): _reg.register_schedule("relay.op.annotation.simulated_quantize", _reg.schedule_injective) _reg.register_pattern("relay.op.annotation.simulated_quantize", - _reg.OpPattern.OPAQUE) + _reg.OpPattern.ELEMWISE) @register_relay_node diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index ac806f2ab5f0..9584b2528c98 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -35,6 +35,7 @@ class QAnnotateKind(object): INPUT = 1 WEIGHT = 2 ACTIVATION = 3 + IDENTITY = 4 def kind2str(kind): @@ -43,6 +44,7 @@ def kind2str(kind): QAnnotateKind.INPUT: "input", QAnnotateKind.WEIGHT: "weight", QAnnotateKind.ACTIVATION: "activation", + QAnnotateKind.IDENTITY: "identity" } assert kind in str_map return str_map[kind] @@ -197,8 +199,9 @@ def annotate_context(): def collect_stats(graph): """Given an annotated graph, create a profile graph to collect profile data from the - calibration dataset. This pass finds simulated_quantize op and collects its input into a tuple. - The tuple is the output of the profile graph. + calibration dataset. This pass collects simulated_quantize op input into a tuple. + Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile + graph. Parameters ---------- diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc index fadc3f2c67ea..7282b56deb21 100644 --- a/src/relay/pass/quantize/calibrate.cc +++ b/src/relay/pass/quantize/calibrate.cc @@ -54,12 +54,22 @@ class StatsCollector : private ExprMutator { CHECK(new_call); if (new_call->op.same_as(simulated_quantize)) { auto attrs = new_call->attrs.as(); + // rewrite the annotation + auto new_attrs = make_node(); const Expr& quantize_input = new_call->args[0]; // expression being quantized + auto placeholder = MakeConstantScalar(Float(32), 0.); // unused argument for simulated_quantize + Array new_args{quantize_input, placeholder, placeholder, placeholder}; + new_attrs->kind = QAnnotateKind::kQIdentity; + new_attrs->sign = attrs->sign; + new_attrs->rounding = attrs->rounding; + Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {}); + + // add non-const expressions to profile data if (attrs->kind != QAnnotateKind::kQWeight) { CHECK(!quantize_input.as()); - profile_data_.push_back(quantize_input); + profile_data_.push_back(identity_quantize); } - return quantize_input; + return identity_quantize; } else { return new_e; } @@ -70,8 +80,8 @@ class StatsCollector : private ExprMutator { * \brief Given an annotated graph, create a profile graph to collect profile data from the * calibration dataset. * - * This pass finds simulated_quantize op and collects its input into a tuple. The tuple is the - * output of the profile graph. Both input and output of this pass + * This pass collects simulated_quantize op into a tuple. Simulated_quantize ops are rewritten to + * identity mode. The tuple is the output of the profile graph. Both input and output of this pass * are relay::Function. * * \param expr The simulation graph after annotation. diff --git a/src/relay/pass/quantize/quantize.h b/src/relay/pass/quantize/quantize.h index e3c88b5c3565..d57e8a875580 100644 --- a/src/relay/pass/quantize/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -40,6 +40,7 @@ enum QAnnotateKind : int { kQInput = 1, kQWeight = 2, kQActivation = 3, + kQIdentity = 4 }; /*! \brief Attribute for simulated quantize operator */ From 0dd38e7d99b9ff12b88eb0178947aea0ff89672f Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sat, 27 Jul 2019 19:19:03 +0800 Subject: [PATCH 11/12] lint --- src/relay/pass/quantize/calibrate.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/quantize/calibrate.cc b/src/relay/pass/quantize/calibrate.cc index 7282b56deb21..30b47ba69a6e 100644 --- a/src/relay/pass/quantize/calibrate.cc +++ b/src/relay/pass/quantize/calibrate.cc @@ -57,7 +57,7 @@ class StatsCollector : private ExprMutator { // rewrite the annotation auto new_attrs = make_node(); const Expr& quantize_input = new_call->args[0]; // expression being quantized - auto placeholder = MakeConstantScalar(Float(32), 0.); // unused argument for simulated_quantize + auto placeholder = MakeConstantScalar(Float(32), 0.); // unused argument Array new_args{quantize_input, placeholder, placeholder, placeholder}; new_attrs->kind = QAnnotateKind::kQIdentity; new_attrs->sign = attrs->sign; From 493d14b9559d44526326f1279d4c427e147fa417 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sun, 28 Jul 2019 13:41:47 +0800 Subject: [PATCH 12/12] kQIdentity = 0 --- python/tvm/relay/quantize/quantize.py | 2 +- src/relay/pass/quantize/quantize.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 9584b2528c98..07d4d9d25e01 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -32,10 +32,10 @@ class QAnnotateKind(object): """Denote the kind of annotation field, corresponding to different nbit configure.""" + IDENTITY = 0 INPUT = 1 WEIGHT = 2 ACTIVATION = 3 - IDENTITY = 4 def kind2str(kind): diff --git a/src/relay/pass/quantize/quantize.h b/src/relay/pass/quantize/quantize.h index d57e8a875580..4965a706b4b4 100644 --- a/src/relay/pass/quantize/quantize.h +++ b/src/relay/pass/quantize/quantize.h @@ -37,10 +37,10 @@ namespace quantize { /*! \brief Kind of annotate field */ enum QAnnotateKind : int { + kQIdentity = 0, kQInput = 1, kQWeight = 2, - kQActivation = 3, - kQIdentity = 4 + kQActivation = 3 }; /*! \brief Attribute for simulated quantize operator */