Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relay/quantize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@

from .quantize import *
from ._annotate import register_annotate_function
from .kl_divergence import kl_divergence_scale
10 changes: 7 additions & 3 deletions python/tvm/relay/quantize/_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):

data, scale, clip_min, clip_max = inputs

if attrs.kind == QAnnotateKind.IDENTITY:
return [topi.identity(data)]

# simulate rounding error
scaled_data = topi.divide(data, scale)
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
Expand All @@ -52,7 +55,7 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):
_reg.register_schedule("relay.op.annotation.simulated_quantize",
_reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize",
_reg.OpPattern.OPAQUE)
_reg.OpPattern.ELEMWISE)


@register_relay_node
Expand Down Expand Up @@ -251,7 +254,7 @@ def add_rewrite(ref_call, new_args, ctx):

if lhs_kind is None and rhs_kind is not None:
# quantize lhs to INPUT field if it is normal expression
assert rhs_kind == QAnnotateKind.INPUT
assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION]
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
Expand All @@ -275,7 +278,8 @@ def add_rewrite(ref_call, new_args, ctx):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT:
if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or \
(lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION):
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError()
Expand Down
124 changes: 124 additions & 0 deletions python/tvm/relay/quantize/kl_divergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Find optimal scale for quantization by minimizing KL-divergence"""

try:
from scipy import stats
except ImportError:
stats = None

import numpy as np


def _smooth_distribution(p, eps=0.0001):
"""Given a discrete distribution (may have not been normalized to 1),
smooth it by replacing zeros with eps multiplied by a scaling factor and taking the
corresponding amount off the non-zero values.
Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf
"""
is_zeros = (p == 0).astype(np.float32)
is_nonzeros = (p != 0).astype(np.float32)
n_zeros = is_zeros.sum()
n_nonzeros = p.size - n_zeros
if not n_nonzeros:
raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
eps1 = eps * float(n_zeros) / float(n_nonzeros)
assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
hist = p.astype(np.float32)
hist += eps * is_zeros + (-eps1) * is_nonzeros
assert (hist <= 0).sum() == 0
return hist


# pylint: disable=invalid-name
def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
"""Given a tensor, find the optimal threshold for quantizing it.
The reference distribution is `q`, and the candidate distribution is `p`.
`q` is a truncated version of the original distribution.

Ref:
http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
"""
assert isinstance(arr, np.ndarray)

min_val = np.min(arr)
max_val = np.max(arr)
th = max(abs(min_val), abs(max_val))

if min_val >= 0 and quantized_dtype in ['uint8']:
# We need to move negative bins to positive bins to fit uint8 range.
num_quantized_bins = num_quantized_bins * 2 + 1

hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th))
zero_bin_idx = num_bins // 2
num_half_quantized_bins = num_quantized_bins // 2

thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2)
divergence = np.zeros_like(thresholds)
quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32)
# i means the number of bins on half axis excluding the zero bin.
for i in range(num_quantized_bins // 2,
num_bins // 2 + 1):
p_bin_idx_start = zero_bin_idx - i
p_bin_idx_stop = zero_bin_idx + i + 1
thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop]
sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop]

# generate reference distribution p
p = sliced_nd_hist.copy()
assert p.size % 2 == 1
assert p.size >= num_quantized_bins
# put left outlier count in p[0]
left_outlier_count = np.sum(hist[0:p_bin_idx_start])
p[0] += left_outlier_count
# put right outlier count in p[-1]
right_outlier_count = np.sum(hist[p_bin_idx_stop:])
p[-1] += right_outlier_count
# is_nonzeros[k] indicates whether hist[k] is nonzero
is_nonzeros = (p != 0).astype(np.int32)

# calculate how many bins should be merged to generate quantized distribution q
num_merged_bins = sliced_nd_hist.size // num_quantized_bins
# merge hist into num_quantized_bins bins
for j in range(num_quantized_bins):
start = j * num_merged_bins
stop = start + num_merged_bins
quantized_bins[j] = sliced_nd_hist[start:stop].sum()
quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum()
# expand quantized_bins into p.size bins
q = np.zeros(sliced_nd_hist.size, dtype=np.float32)
for j in range(num_quantized_bins):
start = j * num_merged_bins
if j == num_quantized_bins - 1:
stop = len(is_nonzeros)
else:
stop = start + num_merged_bins
norm = is_nonzeros[start:stop].sum()
if norm != 0:
q[start:stop] = float(quantized_bins[j]) / float(norm)
q[p == 0] = 0
p = _smooth_distribution(p)
# There is a chance that q is an invalid probability distribution.
try:
q = _smooth_distribution(q)
except ValueError:
divergence[i - num_half_quantized_bins] = float("inf")
divergence[i - num_half_quantized_bins] = stats.entropy(p, q)

min_divergence_idx = np.argmin(divergence)
opt_th = thresholds[min_divergence_idx]
return opt_th
57 changes: 53 additions & 4 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
class QAnnotateKind(object):
"""Denote the kind of annotation field, corresponding
to different nbit configure."""
IDENTITY = 0
INPUT = 1
WEIGHT = 2
ACTIVATION = 3
Expand All @@ -43,6 +44,7 @@ def kind2str(kind):
QAnnotateKind.INPUT: "input",
QAnnotateKind.WEIGHT: "weight",
QAnnotateKind.ACTIVATION: "activation",
QAnnotateKind.IDENTITY: "identity"
}
assert kind in str_map
return str_map[kind]
Expand Down Expand Up @@ -195,7 +197,26 @@ def annotate_context():
return AnnotateContext.Current


def calibrate(graph, mod=None, ctx=None):
def collect_stats(graph):
"""Given an annotated graph, create a profile graph to collect profile data from the
calibration dataset. This pass collects simulated_quantize op input into a tuple.
Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
graph.

Parameters
----------
graph: Function
The simulation graph after annotation.

Returns
-------
ret: Function
The profile graph which outputs a tuple of profile data.
"""
return _quantize.CollectStats(graph)


def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
Expand All @@ -211,6 +232,16 @@ def calibrate(graph, mod=None, ctx=None):
ctx: tvm.relay.PassContext
The pass context used for calibration.

weight_scales: 'power2' or 'max'.
The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
power2: Find the maximum of the absolute value of the tensor, and then round up to power
of two.
max: Find the maximum of the absolute value of the tensor.

scales: List[float]
Pre-calculated scales for input and activations. Length and the order of elements of the
scales list should match the output tuple of the profile graph created by collect_stats.

Returns
-------
ret: Function
Expand All @@ -221,24 +252,42 @@ def power2_scale(arr):
val = np.amax(np.abs(arr.asnumpy()))
return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0

def max_scale(arr):
"""calculate weight scale with maximum absolute value"""
val = np.amax(np.abs(arr.asnumpy()))
return val

scale_idx = 0

cfg = current_qconfig()
const_params = {}
quantize_op = _op.get("relay.op.annotation.simulated_quantize")

def visit_func(expr):
"""Internal visit function"""
nonlocal scale_idx
if isinstance(expr, _expr.Call) and expr.op == quantize_op:
_, ndom_scale, nclip_min, nclip_max = expr.args
attrs = expr.attrs
kind = attrs.kind
nbit = cfg.get_nbit_by_kind(kind)

valid_bit = nbit - attrs.sign

if kind == QAnnotateKind.WEIGHT:
if kind in [QAnnotateKind.WEIGHT]:
if all([isinstance(arg, _expr.Constant)
for arg in [ndom_scale, nclip_min, nclip_max]]):
return
var = expr.args[0]
assert isinstance(var, _expr.Constant)
scale = power2_scale(var.data)
if weight_scales == 'max':
scale = max_scale(var.data)
elif weight_scales == 'power2':
scale = power2_scale(var.data)
else:
raise ValueError('{} not supported'.format(weight_scales))
elif scales is not None:
scale = scales[scale_idx]
scale_idx += 1
else:
scale = cfg.global_scale

Expand Down
99 changes: 99 additions & 0 deletions src/relay/pass/quantize/calibrate.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
*
* \file calibrate.cc
*
* \brief Create profile graph and calibrate on dataset
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include "./quantize.h"


namespace tvm {
namespace relay {
namespace quantize {

class StatsCollector : private ExprMutator {
public:
Expr Collect(const Expr& expr) {
auto new_e = this->Mutate(expr);
const FunctionNode* func = new_e.as<FunctionNode>();
CHECK(func) << "Input shoule be Function";
Expr new_body = TupleNode::make(std::move(profile_data_));
return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
func->attrs);
}

private:
Array<Expr> profile_data_;

Expr VisitExpr_(const CallNode* call) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
Expr new_e = ExprMutator::VisitExpr_(call);
const CallNode* new_call = new_e.as<CallNode>();
CHECK(new_call);
if (new_call->op.same_as(simulated_quantize)) {
auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
// rewrite the annotation
auto new_attrs = make_node<SimulatedQuantizeAttrs>();
const Expr& quantize_input = new_call->args[0]; // expression being quantized
auto placeholder = MakeConstantScalar(Float(32), 0.); // unused argument
Array<Expr> new_args{quantize_input, placeholder, placeholder, placeholder};
new_attrs->kind = QAnnotateKind::kQIdentity;
new_attrs->sign = attrs->sign;
new_attrs->rounding = attrs->rounding;
Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {});

// add non-const expressions to profile data
if (attrs->kind != QAnnotateKind::kQWeight) {
CHECK(!quantize_input.as<ConstantNode>());
profile_data_.push_back(identity_quantize);
}
return identity_quantize;
} else {
return new_e;
}
}
};

/*
* \brief Given an annotated graph, create a profile graph to collect profile data from the
* calibration dataset.
*
* This pass collects simulated_quantize op into a tuple. Simulated_quantize ops are rewritten to
* identity mode. The tuple is the output of the profile graph. Both input and output of this pass
* are relay::Function.
*
* \param expr The simulation graph after annotation.
* \return The profile graph.
*/
Expr CollectStats(const Expr& expr) {
return StatsCollector().Collect(expr);
}

TVM_REGISTER_API("relay._quantize.CollectStats")
.set_body_typed(CollectStats);

} // namespace quantize
} // namespace relay
} // namespace tvm
Loading