diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 404829f74cf7..9a09af828c7c 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -28,6 +28,8 @@ from . import ty as _ty from . import expr as _expr from .module import Module as _Module +#TODO(eqy) +#from . import transform as _transform from .backend import interpreter as _interpreter from .backend.vm import VMExecutor @@ -113,6 +115,43 @@ def build(self, func, target=None, target_host=None, params=None): return graph_json, mod, params +#TODO(eqy) +# def _setup_build_config(self, params): +# cfg = _transform.PassContext.current() +# +# # Set opt_level. +# self.set_opt_level(cfg.opt_level) +# +# # Set fallback device if it is available. +# if cfg.fallback_device: +# self.set_fallback_device(cfg.fallback_device) +# +# # Add required passes. +# if cfg.required_pass: +# passes = set() +# if isinstance(cfg.required_pass, (list, tuple, set)): +# passes = set(cfg.required_pass) +# else: +# raise TypeError("add_pass must be list, tuple, or set, but " + +# "got {}".format(type(cfg.required_pass))) +# for pass_name in passes: +# self.add_pass(pass_name) +# +# # Add disabled passes. +# if cfg.disabled_pass: +# passes = set() +# if isinstance(cfg.disabled_pass, (list, tuple, set)): +# passes = set(cfg.disabled_pass) +# else: +# raise TypeError("disable_pass must be list, tuple, or set, " + +# "but got {}".format(type(cfg.disabled_pass))) +# for pass_name in passes: +# print("DISABLING", pass_name) +# self.disable_pass(pass_name) +# +# if params: +# self._set_params(params) + def _set_params(self, params): inputs = {} for name, param in params.items(): diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 7b7f9c42f2f1..30300816167b 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -22,7 +22,7 @@ import topi from . import _quantize from .quantize import QAnnotateKind, current_qconfig -from .quantize import annotate_context +from .quantize import annotate_context, _get_scale_counter, _set_scale_counter, _get_layout from .. import expr as _expr from .. 
import op as _op from ..op import op as _reg @@ -36,11 +36,62 @@ def simulated_quantize_compute(attrs, inputs, out_type, target): assert len(inputs) == 4 assert attrs.sign assert attrs.rounding == "round" - + data, scale, clip_min, clip_max = inputs + if (data.op == _op.get("concatenate")): + raise Exception + + if 'O' in attrs.layout and 'I' in attrs.layout: + channel_dim = attrs.layout.index('I') + else: + channel_dim = attrs.layout.index('C') + + shape = [] + for i in range(0, len(data.shape)): + shape.append(topi.util.get_const_int(data.shape[i])) + channels = 0 + + if 'broadcastable' in attrs.op_hint and len(data.shape) < len(attrs.layout): + #assert len(data.shape) == len(attrs.layout) - 1, "Unsupported broadcast" + if len(data.shape) == len(attrs.layout) - 1: + for d in range(0, len(data.shape)): + if shape[d] != 1: + channel_dim = d + channels = shape[d] + break + else: + channel_dim = 0 + channels = 1 + + if 'dense' in attrs.op_hint: + channel_dim = 1 + channels = 1 + + if attrs.passthrough: + # if original value should be passed through + assert attrs.kind != QAnnotateKind.WEIGHT + rdata = topi.identity(data) + return [rdata] + # simulate rounding error - scaled_data = topi.divide(data, scale) + if attrs.granularity == 'channel': + assert len(data.shape) >= 4 or 'broadcastable' in attrs.op_hint or 'dense' in attrs.op_hint,\ + "Unsupported Layout / Shape Broadcast" + # TODO consider memory/performance tradeoffs of not using a big scale + # tensor (does this impact fusion with other ops?) + + # TODO blocked layouts + if channels == 0: + channels = shape[channel_dim] + scale_chunk_shape = [1]*len(shape) + scale_chunk_shape[channel_dim] = channels + scale_chunk = topi.reshape(scale, scale_chunk_shape) + scale_tensor = topi.broadcast_to(scale_chunk, shape) + scaled_data = topi.divide(data, scale_tensor) + scale = scale_tensor + else: + scaled_data = topi.divide(data, scale) clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min) round_data = topi.round(clipped_data) @@ -49,6 +100,7 @@ def simulated_quantize_compute(attrs, inputs, out_type, target): return [rdata] + _reg.register_schedule("relay.op.annotation.simulated_quantize", _reg.schedule_injective) _reg.register_pattern("relay.op.annotation.simulated_quantize", @@ -116,7 +168,8 @@ def frewrite_with_guard(ref_call, new_args, ctx): return _register(frewrite) if frewrite is not None else _register -def attach_simulated_quantize(data, kind, sign=True, rounding="round"): +@register_func("relay.quantize.attach_simulated_quantize") +def attach_simulated_quantize(data, kind, layout=None, op_hint="", sign=True, rounding="round"): """Attach a simulated quantize operation after input data expr. Parameters @@ -127,25 +180,34 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): kind: QAnnotateKind the kind of annotation field. 
""" + quantize_op = _op.get("relay.op.annotation.simulated_quantize") if isinstance(data, _expr.Call) and data.op == quantize_op: if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding: return data + #TODO(eqy): check actx = annotate_context() + counter = _get_scale_counter() key = tuple([data, kind, sign, rounding]) if key in actx.qnode_map: return actx.qnode_map[key] - dom_scale = _expr.var("dom_scale") + granularity = current_qconfig().granularity + dom_scale = _expr.var("dom_scale" + str(counter)) clip_min = _expr.var("clip_min") clip_max = _expr.var("clip_max") + passthrough = 0 + passthrough_bound = current_qconfig().passthrough_bound + if kind != QAnnotateKind.WEIGHT: + passthrough = counter > passthrough_bound + _set_scale_counter(counter + 1) + qnode = _quantize.simulated_quantize( - data, dom_scale, clip_min, clip_max, kind, sign, rounding) + data, dom_scale, clip_min, clip_max, kind, sign, rounding, passthrough, granularity, layout, op_hint, -1) actx.qnode_map[key] = qnode - return qnode -register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize) + return qnode @register_annotate_function("nn.contrib_conv2d_NCHWc") @@ -168,17 +230,39 @@ def conv2d_rewrite(ref_call, new_args, ctx): return None actx.count_conv2d() + op_hint = "" + len_hint = -1 + + data_layout = ref_call.attrs.data_layout + kernel_layout = ref_call.attrs.kernel_layout + + _set_conv_counter(cnt + 1) + lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) - if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION: - lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT: + assert _get_layout(ref_call) is not None + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT, _get_layout(ref_call)) assert rhs_kind is None - rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + assert kernel_layout is not None + if 'I' in kernel_layout: + assert 'C' in data_layout + kernel_channels = rhs_expr.data.shape[kernel_layout.index('I')] + if kernel_channels <= 1: + data_channels = ref_call.args[0].checked_type.shape[data_layout.index('C')] + if int(data_channels) > int(kernel_channels): + op_hint = 'depthwise_sep_single' + len_hint = int(data_channels) + + if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT: + assert data_layout is not None + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT, data_layout) + + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT, kernel_layout, op_hint, len_hint) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) - return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) @@ -202,11 +286,16 @@ def dense_rewrite(ref_call, new_args, ctx): lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) - if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION: - lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + op_hint = "dense" + len_hint = 1 + assert _get_layout(ref_call) is not None + if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT: + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT, + _get_layout(ref_call), op_hint, len_hint) assert rhs_kind is None - rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT, + _get_layout(ref_call), op_hint, len_hint) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) 
@@ -226,11 +315,14 @@ def multiply_rewrite(ref_call, new_args, ctx): return None if lhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and rhs_kind is None: + assert _get_layout(ref_call) is not None # quantize lhs to INPUT field - if lhs_kind == QAnnotateKind.ACTIVATION: - lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT, + _get_layout(ref_call)) # quantize rhs to WEIGHT field - rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT, + _get_layout(ref_call), 'broadcastable_mul') + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) @@ -249,23 +341,36 @@ def add_rewrite(ref_call, new_args, ctx): if lhs_kind is None and rhs_kind is None: return None + assert _get_layout(ref_call) is not None if lhs_kind is None and rhs_kind is not None: # quantize lhs to INPUT field if it is normal expression assert rhs_kind == QAnnotateKind.INPUT - lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT, _get_layout(ref_call)) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.INPUT) +#TODO(eqy): check +# if lhs_kind is None and rhs_kind is not None: +# # quantize lhs to INPUT field if it is normal expression +# lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT, +# _get_layout(ref_call)) if lhs_kind is not None and rhs_kind is None: if isinstance(rhs_expr, _expr.Constant): # quantize rhs to WEIGHT field if it is Constant - rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT, + _get_layout(ref_call), 'broadcastable_add') else: # quantize rhs to INPUT field if it is not Constant - rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT, _get_layout(ref_call)) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) +#TODO(eqy): check +# if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION: +# # quantize rhs to INPUT field if both lhs and rhs are ACTIVATION +# rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT, +# _get_layout(ref_call)) +# if lhs_kind is not None and rhs_kind is not None: if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT: expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) @@ -325,9 +430,10 @@ def pool2d_rewrite(ref_call, new_args, ctx): if x_kind is None: return None - if x_kind == QAnnotateKind.ACTIVATION: - expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT) + if x_kind == QAnnotateKind.ACTIVATION: + expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT, + _get_layout(ref_call)) expr = _forward_op(ref_call, [expr]) return QAnnotateExpr(expr, QAnnotateKind.INPUT) @@ -368,7 +474,8 @@ def concatenate_rewrite(ref_call, new_args, ctx): return None for i, k in enumerate(kind_list): if k is None: - expr_list[i] = attach_simulated_quantize(expr_list[i], QAnnotateKind.ACTIVATION) + expr_list[i] = attach_simulated_quantize(expr_list[i], QAnnotateKind.ACTIVATION, _get_layout(ref_call)) + expr = _forward_op(ref_call, [_expr.Tuple(expr_list)]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) diff --git a/python/tvm/relay/quantize/quantize.py 
b/python/tvm/relay/quantize/quantize.py index beebceaf8590..5cd7ccf8ba3e 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -17,7 +17,10 @@ #pylint: disable=unused-argument """Automatic quantization toolkit.""" from __future__ import absolute_import +import time import numpy as np +from scipy import stats + from . import _quantize from .. import expr as _expr @@ -26,9 +29,11 @@ from .. import transform as _transform from .. import op as _op from ... import make as _make +from ..._ffi.function import register_func from ..base import NodeBase, register_relay_node + class QAnnotateKind(object): """Denote the kind of annotation field, corresponding to different nbit configure.""" @@ -75,6 +80,17 @@ class QConfig(NodeBase): "round_for_shift": True, "store_lowbit_output": True, "debug_enabled_ops": None, +#TODO(eqy) +#======= +# "skip_k_conv": 1, +# "skip_conv_layers": None, + "passthrough_bound": 1e9, +# "round_for_shift": True, +# "store_lowbit_output": True, + "debug_enabled_ops": None, +# "use_stop_fusion": True, + "granularity": "layer", +#>>>>>>> check in } # pylint: disable=no-member @@ -187,15 +203,343 @@ def count_conv2d(self): def __exit__(self, ptype, value, traceback): pass - def annotate_context(): """Get the global singleton scope""" if AnnotateContext.Current is None: AnnotateContext.Current = AnnotateContext() return AnnotateContext.Current +#TODO(eqy) +#def calibrate(graph, mod=None, ctx=None): +#======= +SCALE_COUNTER = 0 + + +def _get_scale_counter(): + """Get the global counter for scale setting.""" + return SCALE_COUNTER + + +def _set_scale_counter(n): + """Set the value of the global scale setting counter.""" + global SCALE_COUNTER + SCALE_COUNTER = n + + +LAYOUT_MAP = None + + +def _set_layout_map(layout_map): + global LAYOUT_MAP + LAYOUT_MAP = layout_map + + +def _layout_walk(expr): + conv2d_op = _op.get("nn.conv2d") + if isinstance(expr, _expr.Call): + if expr.op == conv2d_op: + return expr.attrs.data_layout if expr.attrs.out_layout == "" else expr.attrs.out_layout + else: + for arg in expr.args: + if arg in LAYOUT_MAP: + return LAYOUT_MAP[arg] + ret = _layout_walk(arg) + if ret is not None: + return ret + return None + elif isinstance(expr, _expr.Tuple): + for arg in expr.fields: + ret = _layout_walk(arg) + if ret is not None: + return ret + return None + elif isinstance(expr, _expr.TupleGetItem): + return _layout_walk(expr.tuple_value) + raise Exception + + +@register_func("relay.quantize._get_layout") +def _get_layout(expr): + try: + return LAYOUT_MAP[expr] + except KeyError: + ret = _layout_walk(expr) + if ret is not None: + return ret + raise KeyError + + +def annotate(graph, layout_map): + """Given a float32 graph, annotate will rewrite the graph + and return back a graph which simulates the error brought by + current quantization scheme. 
+ + Parameters + --------- + graph: Function + The original graph + + Returns + ------- + ret: Function + The graph after annotation + """ + _set_conv_counter(0) # reset counter + _set_layout_map(layout_map) + return _quantize.annotate(graph) + + +def tag_layout(graph): + conv2d_op = _op.get("nn.conv2d") + dense_op = _op.get("nn.dense") + _op_layout_map = dict() + # layouts to tag later + deferred = set() + + def extract_call_layout(args): + cur_layout = None + for arg in args: + if isinstance(arg, _expr.Call): + assert arg in _op_layout_map + if cur_layout is None: + cur_layout = _op_layout_map[arg] + else: + assert cur_layout == _op_layout_map[arg] + elif isinstance(arg, _expr.Tuple): + return extract_call_layout(arg.fields) + elif isinstance(arg, _expr.TupleGetItem): + return extract_call_layout(arg.tuple_value.args) + return cur_layout + + def visit_func(expr): + """Internal visit function""" + if isinstance(expr, _expr.Call): + cur_layout = None + if expr.op == conv2d_op: + if expr.attrs.out_layout == "": + _op_layout_map[expr] = expr.attrs.data_layout + else: + _op_layout_map[expr] = expr.attrs.out_layout + cur_layout = _op_layout_map[expr] + else: + cur_layout = extract_call_layout(expr.args) + if cur_layout is None: + deferred.add(expr) + else: + _op_layout_map[expr] = cur_layout + if cur_layout is not None: + for arg in expr.args: + if arg in deferred: + _op_layout_map[arg] = cur_layout + deferred.remove(arg) + + _analysis.post_order_visit(graph, visit_func) + if len(deferred) > 0: + raise ValueError + + return _op_layout_map + + +_WEIGHT_SCALE_OPTS = [2**i for i in range(-10, 8)] + + +def slice_idx(begin, end): + import tvm.expr + assert len(begin) == len(end) + for i in range(0, len(begin)): + if not isinstance(end[i], tvm.expr.IntImm) or end[i].value - begin[i].value == 0: + continue + return begin[i].value, end[i].value + raise ValueError + + +# ALIGN weight * data scales for convolution +def match_scales(graph, const_params, mode='max'): + conv2d_op = _op.get("nn.conv2d") + quantize_op = _op.get("relay.op.annotation.simulated_quantize") + + def visit_func(expr): + if isinstance(expr, _expr.Call): + if expr.op == conv2d_op: + quant_weight = expr.args[1] + weight_data, weight_scale_var, _, _ = quant_weight.args + weight_scale = const_params[weight_scale_var].data + + if expr.args[0].op != quantize_op and\ + expr.args[1].op == quantize_op: + unified_scale = np.empty(weight_scale.asnumpy().shape, dtype='float32') + # only weight shift possible + parent = expr.args[0].args[0] + assert parent.op == quantize_op + _, data_scale_var, _, _, = parent.args + data_scale = const_params[data_scale_var].data + if data_scale.shape != weight_scale.shape: + assert expr.args[0].op == _op.get("strided_slice") + begin, end = slice_idx(expr.args[0].attrs.begin, expr.args[0].attrs.end) + product = data_scale.asnumpy()[begin:end] * weight_scale.asnumpy() + else: + product = data_scale.asnumpy() * weight_scale.asnumpy() + if mode == 'max': + unified_scale = max(product) + else: + unified_scale = min(product) + # (d * s_d) * (w * s_w) = (o * s_d * s_w) + # (d * s_d) * (w * s_w') = (o * s_u) + # s_w' = s_u/s_d + gaps = unified_scale/product + weight_scale_transform = np.empty(gaps.shape, dtype='float32') + for i in range(0, gaps.shape[0]): + shift_width = np.log2(gaps[i]) + if shift_width == 0: + weight_scale_transform[i] = 1.0 + else: + weight_scale_transform[i] = 2**shift_width + new_weight_scale = weight_scale.asnumpy()*weight_scale_transform + const_params[weight_scale_var] = 
_expr.const(new_weight_scale) + return + elif expr.args[0].op == quantize_op and\ + expr.args[1].op != quantize_op: + raise ValueError + elif expr.args[0].op != quantize_op and\ + expr.args[1].op != quantize_op: + raise ValueError + + quant_data = expr.args[0] + _, data_scale_var, _, _ = quant_data.args + data_scale = const_params[data_scale_var].data + assert len(data_scale.shape) == 1 + assert len(weight_scale.shape) == 1 + assert weight_scale.shape[0] == data_scale.shape[0] or\ + expr.attrs.groups != 1 # grouped + if weight_scale.shape[0] != 1: + if weight_scale.shape[0] == data_scale.shape[0]: + product = data_scale.asnumpy() * weight_scale.asnumpy() + if mode == 'max': + unified_scale = max(product) + else: + unified_scale = np.median(product) + # (d * s_d) * (w * s_w) = (o * s_d * s_w) + # (d * s_d) * (w * s_w') = (o * s_u) + # s_w' = s_u/s_d + gaps = unified_scale/product + data_scale_transform = np.empty(gaps.shape, dtype='float32') + weight_scale_transform = np.empty(gaps.shape, dtype='float32') + for i in range(0, gaps.shape[0]): + shift_width = np.log2(gaps[i]) + if shift_width == 0: + weight_scale_transform[i] = 1.0 + else: + # magic heuristic, change data scales more + # aggressively than weight scales for + # compensation + weight_scale_transform[i] = 2**(shift_width//2) + data_scale_transform = gaps/weight_scale_transform + new_data_scale = data_scale.asnumpy()*data_scale_transform + new_weight_scale = weight_scale.asnumpy()*weight_scale_transform + const_params[weight_scale_var] = _expr.const(new_weight_scale) + const_params[data_scale_var] = _expr.const(new_data_scale) + # grouped convolution + else: + # match each group's accumulation scales + chunk_size = data_scale.shape[0]//expr.attrs.groups + assert chunk_size == weight_scale.shape[0] + data_scale_np = np.array(data_scale.asnumpy(), copy=True) + weight_scale_np = np.array(weight_scale.asnumpy(), copy=True) + for group in range(0, expr.attrs.groups): + # this part is likely inefficient but easy to reason about + chunk = data_scale_np[group*chunk_size:group*chunk_size+chunk_size] + chunk_prod = chunk * weight_scale_np + unified_scale = max(chunk_prod) + gaps = unified_scale/chunk_prod + data_scale_transform = np.empty(gaps.shape, dtype='float32') + for i in range(0, gaps.shape[0]): + shift_width = np.log2(gaps[i]) + if shift_width == 0: + data_scale_transform[i] = 1.0 + else: + data_scale_transform[i] = 2**(shift_width) + data_scale_np[group*chunk_size:group*chunk_size+chunk_size] =\ + data_scale_np[group*chunk_size:group*chunk_size+chunk_size] * data_scale_transform + assert data_scale_np.shape == data_scale.asnumpy().shape + const_params[data_scale_var] = _expr.const(data_scale_np) -def calibrate(graph, mod=None, ctx=None): + _analysis.post_order_visit(graph, visit_func) + return const_params + + +def _simulate_quantize(array, scale): + # simulate rounding error + + valid_bit = 7 + valid_range = 2**valid_bit + clip_min = - (valid_range - 1) + clip_max = valid_range - 1 + + scale = scale / valid_range + assert scale > 0 + scaled_data = array/scale + clipped_data = np.clip(scaled_data, clip_min, clip_max) + + round_data = np.round(clipped_data) + return round_data*scale + +def _mse_chooser(act, granularity, layout, op_hint=None): + t1 = time.time() + assert len(act.shape) <= 4, "Unsupported layout" + # TODO block layouts + assert layout.upper() == layout, "Blocked layouts not supported" + + if granularity == 'layer' or (op_hint is None and len(act.shape) < len(layout)): + mses = list() + for config_opt in 
_WEIGHT_SCALE_OPTS: + q = _simulate_quantize(act, config_opt) + mse = ((act - q)**2).mean() + mses.append(mse) + t2 = time.time() + scale = _WEIGHT_SCALE_OPTS[np.argmin(mses)] + return np.array([scale], dtype='float32') + else: + if len(act.shape) >= len(layout): + if 'O' in layout and 'I' in layout: + channel_dim = layout.index('I') + else: + channel_dim = layout.index('C') + channels = act.shape[channel_dim] + elif op_hint is not None and 'dense' in op_hint: + channel_dim = 0 + channels = 1 + else: + assert 'broadcastable' in op_hint, "trying to broadcast non-broadcastable op" + if len(act.shape) == len(layout) - 1: + for i in range(0, len(act.shape)): + if act.shape[i] != 1: + channel_dim = i + channels = act.shape[i] + else: + channel_dim = 0 + channels = 1 + + scales = np.array([0.0]*channels, dtype='float32') + for i in range(0, channels): + mses = list() + for config_opt in _WEIGHT_SCALE_OPTS: + sliced_act = np.take(act, i, channel_dim) + q = _simulate_quantize(sliced_act, config_opt) + mse = ((sliced_act - q)**2).mean() + mses.append(mse) + if mse == 0.0: + # use mode as fallback + scales[i] = -1 + break + if scales[i] == 0.0: + scales[i] = _WEIGHT_SCALE_OPTS[np.argmin(mses)] + mode = stats.mode(scales[scales > 0.0])[0] + scales[scales < 0] = mode + t2 = time.time() + return scales + + +def calibrate(graph, dataset=None, profile_mode=False, scales=None): """The calibrate procedure will try to calculate the content of dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` operator. @@ -216,42 +560,98 @@ def calibrate(graph, mod=None, ctx=None): ret: Function The graph after calibration """ - def power2_scale(arr): + if profile_mode: + assert scales is None, "scales should not be passed in with profile_mode" + else: + assert scales is not None, "did not receive scales" + + def power2_scale(arr, granularity, layout, op_hint): """calculate weight scale with nearest mode-2 scale""" val = np.amax(np.abs(arr.asnumpy())) - return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 + + # TODO blocked layout + if granularity == 'channel' or granularity == 'layer': + scale = _mse_chooser(arr.asnumpy(), granularity, layout, op_hint) + return scale + if len(arr.shape) >= 4: + if 'I' in layout: + channel_dim = layout.index('I') + else: + channel_dim = layout.index('C') + channels = arr.shape[channel_dim] + scales = list() + for i in range(0, channels): + val = np.amax(np.abs(np.take(arr.asnumpy(), i, channel_dim))) + scale = 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 + scales.append(scale) + return np.array(scales, dtype='float32') + else: + scale = 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 + return np.array([scale], dtype='float32') + else: + return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 cfg = current_qconfig() const_params = {} quantize_op = _op.get("relay.op.annotation.simulated_quantize") + profile_data = [] + scale_idx = 0 def visit_func(expr): """Internal visit function""" + nonlocal scale_idx if isinstance(expr, _expr.Call) and expr.op == quantize_op: _, ndom_scale, nclip_min, nclip_max = expr.args attrs = expr.attrs kind = attrs.kind + granularity = attrs.granularity + layout = attrs.layout + nbit = cfg.get_nbit_by_kind(kind) valid_bit = nbit - attrs.sign + valid_range = 2**valid_bit + + def _make_const(val): + return _expr.const(val, 'float32') if kind == QAnnotateKind.WEIGHT: var = expr.args[0] assert isinstance(var, _expr.Constant) - scale = power2_scale(var.data) + data = var.data + if False and 'add' in attrs.op_hint: + 
data_np = data.asnumpy() + zero_ind = data_np < 2**-4 + data_np[zero_ind] = np.mean(data_np) + data = _make_const(data).data + scale = power2_scale(data, granularity, layout, attrs.op_hint) + const = _make_const(scale / valid_range) + assert len(const.data.shape) == 1 + const_params[ndom_scale] = const else: - scale = cfg.global_scale - - def _make_const(val): - return _expr.const(val, 'float32') - - valid_range = 2**valid_bit - const_params[ndom_scale] = _make_const(scale / valid_range) + if profile_mode: + profile_data.append((ndom_scale.name_hint, expr.args[0], + granularity, layout)) + else: + const = _make_const(scales[scale_idx]/valid_range) + const_params[ndom_scale] = const + assert len(const.data.shape) == 1 + scale_idx += 1 const_params[nclip_min] = _make_const(- (valid_range - 1)) const_params[nclip_max] = _make_const((valid_range - 1)) _analysis.post_order_visit(graph, visit_func) - return _expr.bind(graph, const_params) + if profile_mode: + for i, val in enumerate(profile_data): + profile_data[i] = (val[0], _expr.bind(val[1], const_params), val[2], val[3]) + else: + const_params = match_scales(graph, const_params) +#TODO(eqy): +# return _expr.bind(graph, const_params) +#======= + #_ir_pass.post_order_visit(graph, visit_func) + return _expr.bind(graph, const_params), profile_data +#>>>>>>> check in def annotate(): @@ -312,6 +712,15 @@ def _bind_params(func, params): bind_dict[arg] = _expr.const(v) return _expr.bind(func, bind_dict) +def optimize(graph): + # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and + # "CanonicalizeOps" optimization before quantization. + opt = _transform.Sequential([_transform.SimplifyInference(), + _transform.FoldConstant(), + _transform.FoldScaleAxis(), + _transform.CanonicalizeOps(), + _transform.FoldConstant()]) + return opt(graph) def quantize(graph, params=None, dataset=None): """ The quantization procedure. Before running the three main @@ -340,13 +749,6 @@ def quantize(graph, params=None, dataset=None): graph = _bind_params(graph, params) mod = _module.Module.from_expr(graph) - # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and - # "CanonicalizeOps" optimization before quantization. - optimize = _transform.Sequential([_transform.SimplifyInference(), - _transform.FoldConstant(), - _transform.FoldScaleAxis(), - _transform.CanonicalizeOps(), - _transform.FoldConstant()]) calibrate_pass = _transform.function_pass(calibrate, opt_level=1, name="QuantizeCalibrate") @@ -366,3 +768,149 @@ def quantize(graph, params=None, dataset=None): mod = quantize_seq(mod) return mod["main"] + +#TODO(eqy) +# # TODO(zhiics) Move this to the pass manager. 
+# graph = optimize(graph, params) +# +# graph = annotate(graph) +# graph = calibrate(graph, dataset) +# graph = realize(graph) +# graph = _ir_pass.fold_constant(graph) +# return graph + +def _evaluate(val_data, batch_fn, graph, lib, params, ctx, free_vars=[], config=[], num_classes=1000, early_stopping=32, log_iter=2): + import mxnet as mx + """Evaluate function for profiling.""" + import tvm + import logging + logging.basicConfig(level=logging.INFO) + from tvm.contrib import graph_runtime + + # create runtime module + m = graph_runtime.create(graph, lib, ctx) + scales = {} + + for i in range(0, len(free_vars)): + free_var = free_vars[i] + if i >= len(config): + shape = m.get_input(i+1).shape + dummy = np.empty(shape=shape) + if len(dummy.shape) > 0: + dummy[:] = np.nan + else: + dummy = np.nan + params[str(free_var.name_hint)] = np.array(dummy) + else: + params[str(free_var.name_hint)] = np.array(config[i]/128) + + m.set_input(**params) + batch_size = 1 + oshape = (batch_size, num_classes) + out_arr = tvm.nd.empty(oshape, "float32") + # setup evaluaiton metric + acc_top1 = mx.metric.Accuracy() + acc_top5 = mx.metric.TopKAccuracy(5) + val_data.reset() + acc_top1.reset() + acc_top5.reset() + # execute + + output_collection = [None]*(m.get_num_outputs() - 1) + for i, batch in enumerate(val_data): + data, label = batch_fn(batch, [mx.cpu(0)]) + m.run(data=data[0].asnumpy()) + m.run(data=data[0].asnumpy(), **scales) + m.get_output(0, out_arr) + for o in range(0, len(output_collection)): + if output_collection[o] is None: + output_collection[o] = m.get_output(o+1).asnumpy() + else: + output_collection[o] = np.concatenate((output_collection[o], m.get_output(o+1).asnumpy())) + acc_top1.update(label, [mx.nd.array(out_arr.asnumpy())]) + acc_top5.update(label, [mx.nd.array(out_arr.asnumpy())]) + _, top1 = acc_top1.get() + _, top5 = acc_top5.get() + + if not (i + 1) % log_iter: + nsamples = (i + 1) * batch_size + print('[{0:d} samples] evaluation: acc-top1={1:f} acc-top5={2:f}'.format(nsamples, top1, top5)) + + if (i+1)*batch_size >= early_stopping: + return top1, output_collection + +#def autoquantize(graph_callback, tr_data, tr_batch_fn, granularity='layer'): +def autoquantize(graph, params, tr_data, tr_batch_fn, granularity='layer'): + + import tvm + import copy + from tvm import relay + #from tvm.relay import ir_pass + from tvm.relay.testing import run_infer_type as infer_type + #graph, params = graph_callback() + graph = _bind_params(graph, params) + graph = optimize(graph, params) + with qconfig(skip_k_conv=0, + passthrough_bound=int(-1), + nbit_input=8, + nbit_weight=8, + global_scale=8.0, + dtype_input='int8', + dtype_weight='int8', + dtype_activation='int32', + store_lowbit_output=True, + debug_enabled_ops=None, + granularity=granularity): + layout_map = tag_layout(graph) + graph = annotate(graph, layout_map) + annotated = copy.deepcopy(graph) + #TODO(eqy) graph = ir_pass.infer_type(graph) + graph = _transform.infer_type(graph) + graph, profile_data = calibrate(graph, profile_mode=True, scales=None) + + #free_vars = list(ir_pass.free_vars(graph)) + free_vars = list(analysis.free_vars(graph), graph) + graph = relay.Function(list(graph.params) + free_vars, + graph.body, graph.ret_type, + graph.type_params, graph.attrs) + additional_outputs = list() + metadata = list() + for hint, data, granularity, layout in profile_data: + additional_outputs.append(data) + metadata.append((hint, granularity, layout)) + graph = relay.Function(graph.params, + relay.expr.Tuple([graph.body]+additional_outputs)) + 
target = 'llvm -mcpu=core-avx2' + #target = 'cuda' + with relay.build_config(opt_level=0): + graph, lib, params = relay.build(graph, target) + ctx = tvm.nd.context(target) + + config = list() + print("calibrating...") + t1 = time.time() + top1, outputs = _evaluate(tr_data, tr_batch_fn, graph, lib, params, ctx, free_vars, early_stopping=64) + for i, output in enumerate(outputs): + config.append(_mse_chooser(output, granularity, metadata[i][-1])) + with qconfig(skip_k_conv=0, + passthrough_bound=int(1e9), + nbit_input=8, + nbit_weight=8, + global_scale=8.0, + dtype_input='int8', + dtype_weight='int8', + dtype_activation='int32', + store_lowbit_output=True, + debug_enabled_ops=None, + granularity=granularity): + graph = annotated + #TODO(eqy) + #graph = ir_pass.infer_type(graph) + graph = infer_type(graph) + graph, profile_data = calibrate(graph, profile_mode=False, scales=config) + graph = realize(graph) + t2 = time.time() + print("calibrated in approx", t2-t1, "s") + with relay.build_config(opt_level=3): + graph = optimize(graph, params) + return graph diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 7b896a8d0f7f..3e457addb9c9 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -64,6 +64,21 @@ class ConstantChecker : private ExprVisitor { } memo_[GetRef(n)] = result; } + + + void VisitExpr_(const CallNode* call) final { + bool result = true; + result &= Check(call->op); + + for (auto arg : call->args) { + result &= Check(arg); + } + memo_[GetRef(call)] = result; + } + + void VisitExpr_(const OpNode* op) final { + memo_[GetRef(op)] = true; + } }; diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 5c303905968e..2374488a2cac 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -361,6 +361,18 @@ inline Expr LeftShift(Expr x, Expr nbit) { } +inline Expr Max(Expr x, Expr y) { + static const Op& op = Op::Get("maximum"); + return CallNode::make(op, {x, y}, Attrs(), {}); +} + + +inline Expr Min(Expr x, Expr y) { + static const Op& op = Op::Get("minimum"); + return CallNode::make(op, {x, y}, Attrs(), {}); +} + + inline Expr ReshapeLike(Expr lhs, Expr rhs) { static const Op& op = Op::Get("reshape_like"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); @@ -372,6 +384,26 @@ inline Expr Copy(Expr data) { return CallNode::make(op, {data}, Attrs(), {}); } +inline Expr Reshape(Expr data, Array newshape) { + static const Op& op = Op::Get("reshape"); + auto attrs = make_node(); + attrs->newshape = newshape; + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +inline Expr ForwardOp(const Call& ref_call, const Array& args) { + return CallNode::make(ref_call->op, + args, ref_call->attrs, ref_call->type_args); +} + +inline Expr Ones(Array shape, + DataType dtype) { + auto attrs = make_node(); + attrs->shape = std::move(shape); + attrs->dtype = std::move(dtype); + static const Op& op = Op::Get("ones"); + return CallNode::make(op, {}, Attrs(attrs), {}); +} Expr MakeConcatenate(Expr data, int axis); diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index 83d9220ccf79..a75c0d0dd2e3 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -51,6 +51,10 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { int kind; bool sign; std::string rounding; + int passthrough; + std::string granularity; + std::string layout; + std::string op_hint; TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { TVM_ATTR_FIELD(kind) 
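
The data-aware calibration path added in quantize.py above (power2_scale with granularity set, _mse_chooser, and autoquantize) selects each scale by minimizing the mean squared quantization error over a fixed list of power-of-two candidates. A minimal sketch of that search, assuming NumPy arrays and the same 8-bit signed grid as _simulate_quantize (candidate list and helper names here are illustrative):

    import numpy as np

    _CANDIDATE_SCALES = [2.0 ** i for i in range(-10, 8)]  # mirrors _WEIGHT_SCALE_OPTS

    def _simulate(arr, scale, valid_bit=7):
        # round-trip arr through a signed (valid_bit + 1)-bit grid with the given scale
        step = scale / 2 ** valid_bit
        q = np.clip(np.round(arr / step), -(2 ** valid_bit - 1), 2 ** valid_bit - 1)
        return q * step

    def choose_scale_mse(arr):
        # pick the candidate scale that minimizes mean squared quantization error
        mses = [np.mean((arr - _simulate(arr, s)) ** 2) for s in _CANDIDATE_SCALES]
        return _CANDIDATE_SCALES[int(np.argmin(mses))]

    # per-layer: one scale for the whole tensor
    act = np.random.randn(1, 16, 8, 8).astype("float32")
    layer_scale = choose_scale_mse(act)

    # per-channel: one scale per slice along the channel axis (axis 1 for NCHW)
    channel_scales = [choose_scale_mse(np.take(act, c, axis=1))
                      for c in range(act.shape[1])]

In the patch this search runs twice: once over weights during calibrate, and once over intermediate activations that autoquantize exposes as extra graph outputs in profile mode.
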
@@ -59,6 +63,14 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { .describe("whether to use signed data type."); TVM_ATTR_FIELD(rounding).set_default("round") .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); + TVM_ATTR_FIELD(passthrough).set_default(false) + .describe("whether to passthrough full precision value (for data-aware calibration)"); + TVM_ATTR_FIELD(granularity).set_default("layer") + .describe("scale granularity. Can be 'global', 'layer', 'channel'"); + TVM_ATTR_FIELD(layout).set_default("unknown") + .describe("data layout (e.g., 'NCHW', 'NHWC')"); + TVM_ATTR_FIELD(op_hint).set_default("") + .describe("operator hint on how to interpret layout (e.g., 'broadcastable')"); } }; @@ -76,7 +88,40 @@ bool SimulatedQuantizeRel(const Array& types, CHECK(data != nullptr); CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; - reporter->Assign(types[1], TensorTypeNode::make({}, Float(32))); // dom_scale + size_t channel_dim = param->layout.find("C"); + if (channel_dim == std::string::npos) + channel_dim = param->layout.find("I"); + + // TODO(eqy): blocked layouts + CHECK(param->layout.find_first_not_of("NCOIHW") == std::string::npos)\ + << "Unsupported Layout in Simulated Quantize"; + + int channels = 1; + if (data->shape.size() >= 4) { + channels = data->shape[channel_dim].as()->value; + } else if (param->op_hint.find("broadcastable") != std::string::npos && data->shape.size() == 3) { + // TODO(eqy): robust broadcast handling, blocked layout support + size_t d = 0; + for (; d < data->shape.size(); d++) { + if (data->shape[d].as()->value != 1) { + channels = data->shape[d].as()->value; + break; + } + } + for (d = d + 1; d < data->shape.size(); d++) { + CHECK_EQ(data->shape[d].as()->value, 1) + << "Unhandled broadcastable data shape" + << data->shape; + } + } else { + channels = 1; + } + + if (param->granularity == "channel") { + reporter->Assign(types[1], TensorTypeNode::make({channels}, Float(32))); // dom_scale + } else { + reporter->Assign(types[1], TensorTypeNode::make({1}, Float(32))); // dom_scale + } reporter->Assign(types[2], TensorTypeNode::make({}, Float(32))); // clip_min reporter->Assign(types[3], TensorTypeNode::make({}, Float(32))); // clip_max reporter->Assign(types[4], types[0]); // output @@ -94,14 +139,21 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") .set_support_level(11) .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); + TVM_REGISTER_API("relay._quantize.simulated_quantize") -.set_body_typed( +.set_body_typed( [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, - int kind, bool sign, std::string rounding) { + int kind, bool sign, std::string rounding, int passthrough, + std::string granularity, std::string layout, std::string op_hint) { auto attrs = make_node(); attrs->kind = kind; attrs->sign = sign; attrs->rounding = rounding; + attrs->passthrough = passthrough; + attrs->granularity = granularity; + attrs->layout = layout; + attrs->op_hint = op_hint; static const Op& op = Op::Get("relay.op.annotation.simulated_quantize"); return CallNode::make(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); }); @@ -114,8 +166,11 @@ Expr QAnnotateExprNode::Realize() const { const auto& cfg = QConfig::Current(); if (cfg->store_lowbit_output) { // store low bit output back for VTA + const PackedFunc* layout_f = runtime::Registry::Get("relay.quantize._get_layout"); + std::string layout = (*layout_f) (this->expr); const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); - 
return (*f)(this->expr, static_cast(kQInput)); + return (*f)(this->expr, static_cast(kQInput), layout, (std::string) +this->expr.as()->op.as()->name); } else { return expr; } @@ -135,9 +190,187 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr") }); +/* +TODO(eqy) +TVM_REGISTER_API("relay._quantize.annotate") +.set_body_typed([] (const Expr& expr) { + std::function fmulti_ref = [](const Expr& e) { + if (e->derived_from()) { + const auto* n = e.as(); + CHECK(n); + const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); + const PackedFunc* layout_f = runtime::Registry::Get("relay.quantize._get_layout"); + std::string layout = (*layout_f) (n->expr); + std::string name = n->expr.as()->op.as()->name; + Expr ret = (*f)(n->expr, static_cast(kQInput), layout, name); + return static_cast(QAnnotateExprNode::make(ret, kQInput)); + } + return e; + }; + return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref); +}); +*/ + + // ============= // realize pass +Expr InferTypeOpt(const Expr& expr) { + auto mod = ModuleNode::FromExpr(expr); + mod = transform::InferType()(mod); + auto entry_func = mod->Lookup("main"); + return expr.as() == nullptr ? entry_func->body : entry_func; +} + +Expr FoldConstantOpt(const Expr& expr) { + auto mod = ModuleNode::FromExpr(expr); + mod = transform::FoldConstant()(mod); + auto entry_func = mod->Lookup("main"); + return expr.as() == nullptr ? entry_func->body : entry_func; +} + +Expr _ReshapeChannelScale(Expr dom_scale, Expr arr, size_t pos) { + auto* dom_scale_tensor = dom_scale.as(); + CHECK(dom_scale_tensor); + Array data_shape; + + if (!arr->checked_type_.defined()) { + //arr = InferType(arr, Module(nullptr)); + arr = InferTypeOpt(arr); + data_shape = arr->checked_type().as()->shape; + } else { + data_shape = arr->checked_type().as()->shape; + } + Array dom_scale_shape = dom_scale_tensor->tensor_type()->shape; + Array broadcast_shape; + + CHECK_LE(dom_scale_shape.size(), 1); + if (dom_scale_shape[0].as()->value == 1) { + // leverage implicit broadcasting + return dom_scale; + } + + int channels = -1; + if (dom_scale_shape.size() == 1) { + channels = dom_scale_shape[0].as()->value; + } + for (size_t i = 0; i < data_shape.size(); i++) { + int dim = data_shape[i].as()->value; + if (i == pos) { + CHECK(dim == channels || dim == 1); + broadcast_shape.push_back(channels); + } else { + broadcast_shape.push_back(1); + } + } + return Reshape(dom_scale, broadcast_shape); +} + +inline bool _IsTensor(Expr dom_scale) { + auto* dom_scale_tensor = dom_scale.as(); + CHECK(dom_scale_tensor); + Array dom_scale_shape = dom_scale_tensor->tensor_type()->shape; + if (dom_scale_shape.size() >= 1) { + CHECK_EQ(dom_scale_shape.size(), 1); + return true; + } + return false; +} + +int _FindChannelPos(Expr arr, const std::string &layout) { + Array data_shape; + if (!arr->checked_type_.defined()) { + //arr = InferType(arr, Module(nullptr)); + arr = InferTypeOpt(arr); + data_shape = arr->checked_type().as()->shape; + } else { + data_shape = arr->checked_type().as()->shape; + } + // TODO(eqy): robust handling of this case + if (data_shape.size() < layout.size()) { + return 0; + } + + int pos = layout.find("C"); + if (pos < 0) { + pos = layout.find("I"); + } + return pos; +} + +inline bool _ConstantEq(Expr s1, Expr s2) { + auto* s1_tensor = s1.as(); + auto* s2_tensor = s2.as(); + CHECK(s1_tensor); + CHECK(s2_tensor); + Array s1_tensor_shape = s1_tensor->tensor_type()->shape; + Array s2_tensor_shape = s2_tensor->tensor_type()->shape; + 
CHECK(s1_tensor_shape.size() == s2_tensor_shape.size()); + // non-vector constants not suported + CHECK_EQ(s1_tensor_shape.size(), 1); + if (s1_tensor_shape[0].as()->value != s2_tensor_shape[0].as()->value) { + size_t dim; + float val; + const ConstantNode* tensor; + if (s1_tensor_shape[0].as()->value == 1) { + dim = s2_tensor_shape[0].as()->value; + tensor = s2_tensor; + val = static_cast(s1_tensor->data->data)[0]; + } else if (s2_tensor_shape[0].as()->value == 1) { + dim = s1_tensor_shape[0].as()->value; + tensor = s2_tensor; + val = static_cast(s1_tensor->data->data)[0]; + } else { + return false; + } + for (size_t i = 0; i < dim; i++) { + if (val !=\ + static_cast(tensor->data->data)[i]) + return false; + } + return true; + } + size_t dim = s1_tensor_shape[0].as()->value; + for (size_t i = 0; i < dim; i++) { + if (static_cast(s1_tensor->data->data)[i] !=\ + static_cast(s2_tensor->data->data)[i]) + return false; + } + return true; +} + +// Eagerly produce a new ConstantNode by applying an elementwise operation to an +// existing ConstantNode with a custom function +inline Expr _FloatLambda(Expr data, float (*func)(float)) { + CHECK(_IsTensor(data)); + auto* data_tensor = data.as(); + CHECK(data_tensor); + Array data_shape = data_tensor->tensor_type()->shape; + std::vector new_data_shape; + CHECK_EQ(data_shape.size(), 1); + + size_t dim = data_shape[0].as()->value; + new_data_shape.push_back(dim); + + DLContext ctx; + ctx.device_type = kDLCPU; + ctx.device_id = 0; + + DLDataType dtype; + dtype.code = kDLFloat; + dtype.bits = 32; + dtype.lanes = 1; + + runtime::NDArray new_data_array = runtime::NDArray::Empty(new_data_shape, dtype, ctx); + + for (size_t i = 0; i < dim; i++) { + reinterpret_cast(new_data_array->data)[i] =\ + (*func)(reinterpret_cast(data_tensor->data->data)[i]); + } + + return ConstantNode::make(new_data_array); +} + Expr QRealizeIntExprNode::Realize() const { const auto& cfg = QConfig::Current(); Expr data = this->data; @@ -146,25 +379,24 @@ Expr QRealizeIntExprNode::Realize() const { } // dequantize data = Cast(data, Float(32)); - data = Multiply(data, this->dom_scale); + int pos = _FindChannelPos(data, this->data_layout); + CHECK_GE(pos, 0); + Expr broadcastable_dom_scale = _ReshapeChannelScale(this->dom_scale, data, pos); + data = Multiply(data, broadcastable_dom_scale); return data; } -QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) { +QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, + DataType dtype, std::string data_layout) { NodePtr n = make_node(); n->data = std::move(data); n->dom_scale = std::move(dom_scale); n->dtype = std::move(dtype); + n->data_layout = std::move(data_layout); return QRealizeIntExpr(n); } -inline Expr ForwardOp(const Call& ref_call, const Array& args) { - return CallNode::make(ref_call->op, - args, ref_call->attrs, ref_call->type_args); -} - - /* calculate `data * s1 / s2`, use shift if possible */ inline Expr MulAndDiv(Expr data, float s1, float s2) { // here we assume the dtype of data is dtype activation @@ -186,6 +418,46 @@ inline Expr MulAndDiv(Expr data, float s1, float s2) { } } +inline Expr MulAndDiv(Expr data, Expr s1, Expr s2, Expr ref_data, const std::string& layout) { + // here we assume the dtype of data is dtype activation + CHECK(_IsTensor(s1)); + CHECK(_IsTensor(s2)); + if (_ConstantEq(s1, s2)) return data; + // should be constant + Expr factor = Divide(s1, s2); + factor = FoldConstantOpt(factor); + // should be constant + Expr shift_factor = _FloatLambda(factor, 
&std::log2f); + auto* shift_factor_tensor = shift_factor.as(); + CHECK(shift_factor_tensor); + Array shift_factor_tensor_shape = shift_factor_tensor->tensor_type()->shape; + int64_t channels = shift_factor_tensor_shape[0].as()->value; + DLContext ctx; + ctx.device_type = kDLCPU; + ctx.device_id = 0; + DLDataType dtype; + dtype.code = kDLInt; + dtype.bits = 32; + dtype.lanes = 1; + runtime::NDArray shift_array = runtime::NDArray::Empty({channels}, dtype, ctx); + for (int64_t dim = 0; dim < channels; dim++) { + float cur_shift_factor = static_cast(shift_factor_tensor->data->data)[dim]; + // currently only support power of two scaling + CHECK(static_cast(cur_shift_factor) == cur_shift_factor); + reinterpret_cast(shift_array->data)[dim] =\ + static_cast(cur_shift_factor); + } + int pos = _FindChannelPos(ref_data, layout); + CHECK_GE(pos, 0); + Expr broadcastable_shift = _ReshapeChannelScale(ConstantNode::make(shift_array), ref_data, pos); + return LeftShift(data, broadcastable_shift); +} + +float _RoundBias(float shift_nbit) { + float round_bias = std::pow(2.0, shift_nbit - 1); + return round_bias; +} + Expr QuantizeRealize(const Call& ref_call, const Array& new_args, const NodeRef& ctx) { @@ -198,62 +470,146 @@ Expr QuantizeRealize(const Call& ref_call, Expr clip_min = new_args[2]; Expr clip_max = new_args[3]; - float dom_scale_imm = GetScalarFromConstant(dom_scale); float clip_min_imm = GetScalarFromConstant(clip_min); float clip_max_imm = GetScalarFromConstant(clip_max); + auto* dom_scale_tensor = dom_scale.as(); + CHECK(dom_scale_tensor); + Array dom_scale_shape = dom_scale_tensor->tensor_type()->shape; + + std::string layout = param->layout; + // x * idom_scale = y * odom_scale // => y = x * idom_scale / odom_scale if (const auto* n = new_args[0].as()) { // int32->int8 Expr data = n->data; - float idom_scale_imm = GetScalarFromConstant(n->dom_scale); - float odom_scale_imm = GetScalarFromConstant(dom_scale); - if (idom_scale_imm == odom_scale_imm) { - // same domain scale, only clip - data = Clip(data, clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(data, dom_scale, n->dtype); - } - - float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); - CHECK_GT(shift_nbit, 0); - if (static_cast(shift_nbit) == shift_nbit) { - // use right shift + auto* idom_scale_tensor = n->dom_scale.as(); + CHECK(idom_scale_tensor); + Array idom_scale_shape = idom_scale_tensor->tensor_type()->shape; +/* TODO(eqy) + if (dom_scale_shape.size() >= 1) { + CHECK(dom_scale_shape.size() == 1); + CHECK(idom_scale_shape.size() == 1); + size_t dom_scale_channels = dom_scale_shape[0].as()->value; + size_t idom_scale_channels = idom_scale_shape[0].as()->value; + CHECK(dom_scale_channels == idom_scale_channels || dom_scale_channels == 1 + || idom_scale_channels == 1); + Expr factor = Divide(dom_scale, n->dom_scale); + factor = FoldConstantOpt(factor); + auto* factor_tensor = factor.as(); + CHECK(factor_tensor != nullptr); + Expr shift_factor = _FloatLambda(factor, &std::log2f); + auto* shift_factor_tensor = shift_factor.as(); + CHECK(shift_factor_tensor != nullptr); + size_t dim = shift_factor_tensor->data->shape[0]; + for (size_t i = 0; i < dim; i++) { + float val = static_cast(shift_factor_tensor->data->data)[i]; + CHECK(static_cast(val) == val); + } if (cfg->round_for_shift) { - float round_bias = std::pow(2.0, shift_nbit - 1); - data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast(round_bias))); + Expr round_bias = _FloatLambda(shift_factor, _RoundBias); + round_bias = 
FoldConstantOpt(round_bias); + int pos = _FindChannelPos(ref_call->args[0], layout); + CHECK(pos >= 0); + round_bias = _ReshapeChannelScale(round_bias, ref_call->args[0], pos); + round_bias = FoldConstantOpt(round_bias); + CHECK(round_bias.as() != nullptr); + round_bias = FoldConstantOpt(round_bias); + round_bias = Cast(round_bias, n->dtype); + // TODO: why can we not use cfg->dtype_activation? + data = Add(data, round_bias); } - data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); - data = Clip(data, clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(data, dom_scale, n->dtype); - } else { - // float computation - data = Cast(data, Float(32)); - Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale)); - Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); + */ + CHECK_EQ(dom_scale_shape.size(), 1); // remove support for floating point scalar case + + CHECK_EQ(dom_scale_shape.size(), 1); + CHECK_EQ(idom_scale_shape.size(), 1); + size_t dom_scale_channels = dom_scale_shape[0].as()->value; + size_t idom_scale_channels = idom_scale_shape[0].as()->value; + CHECK(dom_scale_channels == idom_scale_channels || dom_scale_channels == 1 + || idom_scale_channels == 1); + Expr factor = Divide(dom_scale, n->dom_scale); + factor = FoldConstantOpt(factor); + auto* factor_tensor = factor.as(); + CHECK(factor_tensor != nullptr); + Expr shift_factor = _FloatLambda(factor, &std::log2f); + auto* shift_factor_tensor = shift_factor.as(); + CHECK(shift_factor_tensor != nullptr); + size_t dim = shift_factor_tensor->data->shape[0]; + for (size_t i = 0; i < dim; i++) { + float val = static_cast(shift_factor_tensor->data->data)[i]; + CHECK(static_cast(val) == val); + } + if (cfg->round_for_shift) { + Expr round_bias = _FloatLambda(shift_factor, _RoundBias); + round_bias = FoldConstantOpt(round_bias); + int pos = _FindChannelPos(ref_call->args[0], layout); + CHECK_GE(pos, 0); + round_bias = _ReshapeChannelScale(round_bias, ref_call->args[0], pos); + round_bias = FoldConstantOpt(round_bias); + CHECK(round_bias.as() != nullptr); + round_bias = FoldConstantOpt(round_bias); + round_bias = Cast(round_bias, n->dtype); + // TODO(eqy): why can we not use cfg->dtype_activation? 
+ data = Add(data, round_bias); } + int pos = _FindChannelPos(ref_call->args[0], layout); + CHECK_GE(pos, 0); + shift_factor = _ReshapeChannelScale(shift_factor, data, pos); + shift_factor = Cast(shift_factor, n->dtype); + data = RightShift(data, shift_factor); + data = Clip(data, clip_min_imm, clip_max_imm); + Expr res = QRealizeIntExprNode::make(data, dom_scale, n->dtype, layout); + return res; } // quantize from real CHECK(!new_args[0]->derived_from()); Expr data = new_args[0]; - Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm)); + Expr scaled_data; + CHECK(dom_scale_shape.size() >= 1); + CHECK(dom_scale_shape.size() == 1); + int pos = _FindChannelPos(ref_call->args[0], layout); + CHECK(pos >= 0); + Expr broadcastable_dom_scale = _ReshapeChannelScale(dom_scale, new_args[0], pos); + scaled_data = Multiply(data, Divide(MakeConstantScalar(Float(32), 1), + broadcastable_dom_scale)); Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); - return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); + return QRealizeIntExprNode::make(round_data, dom_scale, Float(32), layout); } -Expr FoldConstantOpt(const Expr& expr) { - auto mod = ModuleNode::FromExpr(expr); - mod = transform::FoldConstant()(mod); - auto entry_func = mod->Lookup("main"); - return expr.as() == nullptr ? entry_func->body : entry_func; -} + RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") .set_attr("FQRealizeRewrite", QuantizeRealize); +bool _IsStridedSlice(Expr arg) { + auto ref_arg = arg.as(); + if (ref_arg && ref_arg->op == Op::Get("strided_slice")) { + return true; + } + return false; +} + +void _GetStridedIdx(Expr arg, std::vector &idx) { + auto ref_arg = arg.as(); + auto param = ref_arg->attrs.as(); + for (size_t i = 0; i < param->begin.size(); i++) { + auto intimm1 = param->begin[i].as(); + auto intimm2 = param->end[i].as(); + if (!intimm1 || !intimm2) { + continue; + } + if (intimm2->value - intimm1->value == 0) { + continue; + } + CHECK(intimm1->value >= 0); + CHECK(intimm1->value >= 0); + idx.push_back(intimm1->value); + idx.push_back(intimm2->value); + } +} Expr Conv2dRealize(const Call& ref_call, const Array& new_args, @@ -272,7 +628,8 @@ Expr Conv2dRealize(const Call& ref_call, if (lhs->dtype != cfg->dtype_input) { ldata = Cast(ldata, cfg->dtype_input); } - Expr rdata = Cast(rhs->data, cfg->dtype_weight); + Expr rdata = rhs->data; + rdata = Cast(rdata, cfg->dtype_weight); const auto ref_attrs = ref_call->attrs.as(); auto attrs = make_node(); @@ -280,17 +637,77 @@ Expr Conv2dRealize(const Call& ref_call, DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; + Expr input_scale = lhs->dom_scale; + Expr weight_scale = rhs->dom_scale; + + Array data_shape = input_scale.as()->tensor_type()->shape; + Array weight_shape = weight_scale.as()->tensor_type()->shape; + CHECK(data_shape.size() == 1); + CHECK(weight_shape.size() == 1); + size_t data_dim = data_shape[0].as()->value; + size_t weight_dim = weight_shape[0].as()->value; + + Expr dom_scale; + /* Special handling for strided_slice is needed because it changes the number + * of channel dimensions and the number of per-channel scales. 
We may consider
+ * changing the strided_slice rewrite to something other than identity to avoid
+ * this issue. */
+  if (data_dim == weight_dim) {
+    // TODO(eqy): special handling when both scales are layer-wise (size 1); we
+    // could skip this calculation and only do the old style.
+    auto* data_scale_tensor = input_scale.as<ConstantNode>();
+    auto* weight_scale_tensor = weight_scale.as<ConstantNode>();
+
+    // The current scheme relies on the per-channel products of the data and
+    // weight scales matching after multiplication.
+    float max_output_scale =
+        reinterpret_cast<float*>(data_scale_tensor->data->data)[0] *
+        reinterpret_cast<float*>(weight_scale_tensor->data->data)[0];
+
+    for (size_t i = 0; i < weight_dim; i++) {
+      float cur_output_scale =
+          reinterpret_cast<float*>(data_scale_tensor->data->data)[i] *
+          reinterpret_cast<float*>(weight_scale_tensor->data->data)[i];
+      CHECK(cur_output_scale == max_output_scale);
+    }
+    dom_scale = Multiply(Ones({1}, Float(32)), MakeConstantScalar(Float(32), max_output_scale));
+    dom_scale = FoldConstantOpt(dom_scale);
+
+    CHECK(dom_scale.as<ConstantNode>());
+    CHECK(dom_scale.as<ConstantNode>()->tensor_type()->shape.size() == 1);
+  } else if (data_dim == weight_dim * attrs->groups && weight_dim != 1) {
+    Array<Expr> weight_scale_tuple;
+    for (int i = 0; i < attrs->groups; i++) {
+      weight_scale_tuple.push_back(weight_scale);
+    }
+    dom_scale = Multiply(input_scale, MakeConcatenate(TupleNode::make(weight_scale_tuple), 0));
+    dom_scale = FoldConstant(dom_scale);
+    CHECK(dom_scale.as<ConstantNode>());
+  } else {
+    // depthwise
+    CHECK(weight_dim == 1);
+
+    // unmatched scales are fine for depthwise convolution
+    dom_scale = Multiply(input_scale, weight_scale);
+    dom_scale = FoldConstantOpt(dom_scale);
+    CHECK(dom_scale.as<ConstantNode>());
+    CHECK((size_t) dom_scale.as<ConstantNode>()->tensor_type()->shape[0].as<IntImm>()->value == data_dim);
+  }
   Expr ret = CallNode::make(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+/*
+  TODO(eqy)
   Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
   Expr dom_scale = FoldConstantOpt(mul);
   return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
+*/
+
+  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype, attrs->data_layout);
 }
 RELAY_REGISTER_OP("nn.conv2d")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
-
 Expr DenseRealize(const Call& ref_call,
                   const Array<Expr>& new_args,
                   const NodeRef& ctx) {
@@ -301,30 +718,39 @@ Expr DenseRealize(const Call& ref_call,
   }
   const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
   const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
+  CHECK(lhs);
+  CHECK(rhs);
   Expr ldata = lhs->data;
   if (lhs->dtype != cfg->dtype_input) {
     ldata = Cast(ldata, cfg->dtype_input);
   }
   Expr rdata = Cast(rhs->data, cfg->dtype_weight);
-
   const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
+  CHECK(ref_attrs);
   auto attrs = make_node<DenseAttrs>();
   *attrs = *ref_attrs;
   DataType out_dtype = cfg->dtype_activation;
   attrs->out_dtype = out_dtype;
-
   Expr ret = CallNode::make(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+/*
   Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
   Expr dom_scale = FoldConstantOpt(mul);
   return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
+*/
+  Expr dom_scale = FoldConstantOpt(Multiply(lhs->dom_scale, rhs->dom_scale));
+  CHECK(dom_scale.as<ConstantNode>());
+  // CHECK(ref_call->args[0].as());
+  const PackedFunc* layout_f = runtime::Registry::Get("relay.quantize._get_layout");
+  std::string layout = (*layout_f)(ref_call);
+
+  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype, layout);
 }
 RELAY_REGISTER_OP("nn.dense")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
-
 Expr MulRealize(const Call& ref_call,
                 const Array<Expr>& new_args,
                 const NodeRef& ctx) {
@@ -336,23 +762,24 @@ Expr MulRealize(const Call& ref_call,
   const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
   Expr ldata = lhs->data;
   Expr rdata = rhs->data;
-
   DataType dtype = cfg->dtype_activation;
-  if (lhs->dtype == Float(32)) {
-    ldata = Cast(ldata, dtype);
-  } else {
-    CHECK_EQ(lhs->dtype, dtype);
-  }
-  if (rhs->dtype == Float(32)) {
-    rdata = Cast(rdata, dtype);
-  } else {
-    CHECK_EQ(rhs->dtype, dtype);
-  }
+  ldata = Cast(ldata, dtype);
+  rdata = Cast(rdata, dtype);
   Expr ret = ForwardOp(ref_call, {ldata, rdata});
+  /*
+  TODO(eqy): check
   Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
   Expr dom_scale = FoldConstantOpt(mul);
   return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+  */
+  Expr dom_scale = FoldConstantOpt(Multiply(lhs->dom_scale, rhs->dom_scale));
+  CHECK(dom_scale.as<ConstantNode>());
+  CHECK(dom_scale.as<ConstantNode>()->tensor_type()->shape.size() == 1);
+  const PackedFunc* layout_f = runtime::Registry::Get("relay.quantize._get_layout");
+  std::string layout = (*layout_f)(ref_call);
+
+  return QRealizeIntExprNode::make(ret, dom_scale, dtype, layout);
 }
 CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
 return Expr(nullptr);
@@ -362,25 +789,101 @@ RELAY_REGISTER_OP("multiply")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
-float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
+Expr ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs,
+                    bool max = false) {
+  DLContext ctx;
+  ctx.device_type = kDLCPU;
+  ctx.device_id = 0;
+
+  DLDataType dtype;
+  dtype.code = kDLFloat;
+  dtype.bits = 32;
+  dtype.lanes = 1;
   if (nptrs.size() == 2) {
     // x = a * s1, y = b * s2
     // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
     //       = (a + b * s2 / s1) * s1, if s2 > s1
-    float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
-    float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
-    return s1 > s2 ? s2 : s1;
+    Expr s1 = nptrs[0]->dom_scale;
+    Expr s2 = nptrs[1]->dom_scale;
+    auto* s1_tensor = s1.as<ConstantNode>();
+    auto* s2_tensor = s2.as<ConstantNode>();
+    CHECK(s1_tensor);
+    CHECK(s2_tensor);
+    Array<IndexExpr> s1_shape = s1_tensor->tensor_type()->shape;
+    Array<IndexExpr> s2_shape = s2_tensor->tensor_type()->shape;
+    // tensor dom scales
+    CHECK(s1_shape.size() >= 1);
+    CHECK(s1_shape.size() == 1);
+    CHECK(s2_shape.size() == 1);
+    // broadcasting
+    if (s1_shape[0].as<IntImm>()->value != s2_shape[0].as<IntImm>()->value) {
+      CHECK(s1_shape[0].as<IntImm>()->value == 1 || s2_shape[0].as<IntImm>()->value == 1);
+      const ConstantNode* single;
+      const ConstantNode* broadcast_to;
+      if (s1_shape[0].as<IntImm>()->value == 1) {
+        single = s1_tensor;
+        broadcast_to = s2_tensor;
+      } else {
+        single = s2_tensor;
+        broadcast_to = s1_tensor;
+      }
+      float cur_s1 = reinterpret_cast<float*>(single->data->data)[0];
+      int64_t dim = broadcast_to->tensor_type()->shape[0].as<IntImm>()->value;
+
+      runtime::NDArray s = runtime::NDArray::Empty({dim}, dtype, ctx);
+      for (int64_t i = 0; i < dim; i++) {
+        float cur_s2 = reinterpret_cast<float*>(broadcast_to->data->data)[i];
+        float cur_s = cur_s1 > cur_s2 ? cur_s2 : cur_s1;
+        reinterpret_cast<float*>(s->data)[i] = cur_s;
+      }
+      return ConstantNode::make(s);
+    } else {
+      int64_t dim = s1_shape[0].as<IntImm>()->value;
+      runtime::NDArray s = runtime::NDArray::Empty({dim}, dtype, ctx);
+      for (int64_t i = 0; i < dim; i++) {
+        float cur_s1 = reinterpret_cast<float*>(s1_tensor->data->data)[i];
+        float cur_s2 = reinterpret_cast<float*>(s2_tensor->data->data)[i];
+        reinterpret_cast<float*>(s->data)[i] = cur_s1 > cur_s2 ? cur_s2 : cur_s1;
+      }
+      return ConstantNode::make(s);
+    }
+  } else if (max) {
+    Expr scale;
+    std::vector<float> scales;
+    for (size_t i = 0; i < nptrs.size(); i++) {
+      Expr s = nptrs[i]->dom_scale;
+      auto* s_tensor = s.as<ConstantNode>();
+      CHECK(s_tensor);
+      Array<IndexExpr> s_shape = s_tensor->tensor_type()->shape;
+      CHECK_EQ(s_shape[0].as<IntImm>()->value, 1);
+      scales.push_back(static_cast<float*>(s_tensor->data->data)[0]);
+      if (!i) {
+        scale = s;
+      } else {
+        scale = Min(s, scale);
+      }
+    }
+    return FoldConstantOpt(scale);
   } else {
+    LOG(INFO) << "WARNING, using global scale";
     const QConfig& cfg = QConfig::Current();
     float scale = cfg->global_scale;
-    return scale / std::pow(2.0, cfg->nbit_activation - 1);
+    scale = scale / std::pow(2.0, cfg->nbit_activation - 1);
+    runtime::NDArray s = runtime::NDArray::Empty({1}, dtype, ctx);
+    reinterpret_cast<float*>(s->data)[0] = scale;
+    Expr scale_constant = ConstantNode::make(s);
+    return scale_constant;
   }
 }
-
 /* \brief Unify the dom scale of arguments */
-Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
-                            DataType* dtype_ptr, Expr* scale_ptr) {
+Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
+                            const Array<Expr>& args,
+                            DataType* dtype_ptr,
+                            Expr* scale_ptr,
+                            const std::string& layout,
+                            bool min = false) {
   static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
   const QConfig& cfg = QConfig::Current();
@@ -414,11 +917,11 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args
   }
   // unify the dom_scale
-  float s = ChooseDomScale(nptrs);
-  Expr dom_scale = MakeConstantScalar(Float(32), s);
+  // s should be a constant, created by ChooseDomScale
+  Expr dom_scale = ChooseDomScale(nptrs, min);
   for (size_t i = 0; i < ret.size(); ++i) {
-    float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
-    ret.Set(i, MulAndDiv(ret[i], cur_s, s));
+    Expr cur_s = nptrs[i]->dom_scale;
+    ret.Set(i, MulAndDiv(ret[i], cur_s, dom_scale, ref_args[i], layout));
   }
   *dtype_ptr = dtype;
@@ -426,6 +929,7 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args
   return ret;
 }
+
 Expr AddRealize(const Call& ref_call,
                 const Array<Expr>& new_args,
                 const NodeRef& ctx) {
@@ -433,15 +937,18 @@ Expr AddRealize(const Call& ref_call,
   if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
     DataType dtype;
     Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
+    const PackedFunc* layout_f = runtime::Registry::Get("relay.quantize._get_layout");
+    std::string layout = (*layout_f)(ref_call);
+    Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale, layout);
     Expr ret = ForwardOp(ref_call, ret_args);
-    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+    return QRealizeIntExprNode::make(ret, dom_scale, dtype, layout);
   }
   CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
   return Expr(nullptr);
 }
+
 RELAY_REGISTER_OP("add")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
@@ -458,7 +965,7 @@ Expr ClipRealize(const Call& ref_call,
   Expr ret = CallNode::make(ref_call->op, {n->data}, Attrs(attrs), ref_call->type_args);
-  return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
+  return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype, n->data_layout);
 }
 CHECK(!new_args[0]->derived_from<TempExprNode>());
 return Expr(nullptr);
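The following standalone snippet is not part of the patch; it is a minimal numeric sketch of the rule stated in the ChooseDomScale comment above (move both addends to the smaller per-channel dom scale before adding), with all names and values made up for illustration.

    // Illustrative only: unify two per-channel scales by taking the smaller one
    // and rescaling the other operand into it, then add in the common scale.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<float> s1 = {0.50f, 0.25f};  // per-channel dom_scale of x
      std::vector<float> s2 = {0.25f, 0.50f};  // per-channel dom_scale of y
      std::vector<int> a = {3, 8};             // quantized values of x
      std::vector<int> b = {10, 2};            // quantized values of y

      for (size_t c = 0; c < s1.size(); ++c) {
        float s = std::min(s1[c], s2[c]);      // unified dom_scale for this channel
        float ra = a[c] * (s1[c] / s);         // rescale x into units of s
        float rb = b[c] * (s2[c] / s);         // rescale y into units of s
        std::printf("channel %zu: real-valued sum = %g\n", c, (ra + rb) * s);
      }
      return 0;
    }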
@@ -468,25 +975,94 @@ RELAY_REGISTER_OP("clip")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
+/* \brief Unify the dom scale of arguments */
+Array<Expr> ConcatenateDTypeScale(const Array<Expr>& ref_args,
+                                  const Array<Expr>& args,
+                                  DataType* dtype_ptr,
+                                  Expr* scale_ptr,
+                                  const std::string& layout) {
+  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
+  const QConfig& cfg = QConfig::Current();
+
+  std::vector<const QRealizeIntExprNode*> nptrs;
+  Array<Expr> ret;
+  for (auto arg : args) {
+    const auto* nptr = arg.as<QRealizeIntExprNode>();
+    CHECK(nptr);
+    nptrs.push_back(nptr);
+    ret.push_back(nptr->data);
+  }
+  // unify the data type
+  CHECK_EQ(ref_args.size(), args.size());
+  DataType dtype = cfg->dtype_activation;
+  for (size_t i = 0; i < ret.size(); ++i) {
+    auto ref_arg = ref_args[i].as<CallNode>();
+    if (nptrs[i]->dtype != dtype) {
+      ret.Set(i, Cast(ret[i], dtype));
+    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
+               ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
+      auto new_arg = Cast(ret[i], cfg->dtype_input);
+      // TODO(eqy): if (cfg->use_stop_fusion) {
+      //   new_arg = StopFusion(new_arg);
+      // }
+      ret.Set(i, Cast(new_arg, dtype));
+    }
+  }
+  // unify the dom_scale
+  // s should be a constant, created by ChooseDomScale
+  Array<Expr> dom_scales;
+  for (size_t i = 0; i < ret.size(); ++i) {
+    Expr data = ref_args[i];
+    if (!data->checked_type_.defined()) {
+      // data = InferType(data, Module(nullptr));
+      data = InferTypeOpt(data);
+    }
+    int pos = _FindChannelPos(data, layout);
+    int dom_scale_dim = nptrs[i]->dom_scale.as<ConstantNode>()->tensor_type()->shape[0].as<IntImm>()->value;
+    int channels = data->checked_type().as<TensorTypeNode>()->shape[pos].as<IntImm>()->value;
+    if (channels != dom_scale_dim) {
+      CHECK(dom_scale_dim == 1);
+      dom_scales.push_back(FoldConstantOpt(Multiply(Ones({channels}, Float(32)), nptrs[i]->dom_scale)));
+    } else {
+      dom_scales.push_back(nptrs[i]->dom_scale);
+    }
+  }
+  Expr dom_scale = MakeConcatenate(TupleNode::make(dom_scales), 0);
+  dom_scale = FoldConstantOpt(dom_scale);
+  *dtype_ptr = dtype;
+  *scale_ptr = dom_scale;
+  return ret;
+}
+
+
 Expr ConcatenateRealize(const Call& ref_call,
                         const Array<Expr>& new_args,
                         const NodeRef& ctx) {
   CHECK_EQ(new_args.size(), 1);
   CHECK_EQ(ref_call->args.size(), 1);
-
   const auto* tuple = new_args[0].as<TupleNode>();
   const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
-  CHECK(tuple);
   CHECK(ref_tuple);
+  CHECK(tuple);
   const Array<Expr>& arr = tuple->fields;
   const Array<Expr>& ref_arr = ref_tuple->fields;
   if (arr[0].as<QRealizeIntExprNode>()) {
     DataType dtype;
     Expr dom_scale;
-    Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
+    // Check whether this is a per-channel concatenate.
+    // TODO(eqy): consider adding granularity as a field instead of relying on
+    // a brittle heuristic.
+    if (arr[0].as<QRealizeIntExprNode>()->dom_scale.as<ConstantNode>()->tensor_type()->shape[0].as<IntImm>()->value > 1) {
+      Array<Expr> ret_args = ConcatenateDTypeScale(ref_arr, arr, &dtype, &dom_scale, arr[0].as<QRealizeIntExprNode>()->data_layout);
+      Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
+      return QRealizeIntExprNode::make(ret, dom_scale, dtype, arr[0].as<QRealizeIntExprNode>()->data_layout);
+
+    } else {
+      Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale, arr[0].as<QRealizeIntExprNode>()->data_layout, true);
     Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
-    return QRealizeIntExprNode::make(ret, dom_scale, dtype);
+      return QRealizeIntExprNode::make(ret, dom_scale, dtype, arr[0].as<QRealizeIntExprNode>()->data_layout);
+    }
   } else {
     for (auto arg : new_args) {
       CHECK(!arg->derived_from<TempExprNode>());
@@ -495,6 +1071,7 @@ Expr ConcatenateRealize(const Call& ref_call,
   }
 }
+
 RELAY_REGISTER_OP("concatenate")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
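The following standalone snippet is likewise not part of the patch; it sketches the idea behind ConcatenateDTypeScale above: for a channel-axis concatenate, the output's per-channel dom scale is the concatenation of each input's per-channel scales, with any layer-wise scalar scale first broadcast to that input's channel count. All values are made up for illustration.

    // Illustrative only: build the output's per-channel scale for a concatenate.
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<float> scale_a = {0.5f, 0.25f, 0.125f};     // input A: 3 channels, per-channel scale
      float scale_b = 0.0625f;                                 // input B: layer-wise (scalar) scale
      int channels_b = 2;                                      // input B contributes 2 channels

      std::vector<float> out_scale(scale_a);                   // channels of A keep their scales
      out_scale.insert(out_scale.end(), channels_b, scale_b);  // broadcast B's scalar, then append

      for (size_t c = 0; c < out_scale.size(); ++c) {
        std::printf("output channel %zu -> dom_scale %g\n", c, out_scale[c]);
      }
      return 0;
    }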
@@ -505,13 +1082,25 @@ Expr IdentityRealize(const Call& ref_call,
                      const NodeRef& ctx) {
   CHECK_EQ(new_args.size(), 1);
   if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
-    Expr ret = ForwardOp(ref_call, {n->data});
-    return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
+    int scale_dim = n->dom_scale.as<ConstantNode>()->tensor_type()->shape[0].as<IntImm>()->value;
+    // TODO(eqy): use a more reliable check for per-layer scale
+    if (ref_call->op == Op::Get("strided_slice") && scale_dim > 1) {
+      std::vector<int64_t> idx;
+      _GetStridedIdx(ref_call, idx);
+      Expr sliced_scale = MakeStridedSlice(n->dom_scale, {(int) idx[0]}, {(int) idx[1]}, {1});
+      sliced_scale = FoldConstantOpt(sliced_scale);
+      Expr ret = ForwardOp(ref_call, {n->data});
+      return QRealizeIntExprNode::make(ret, sliced_scale, n->dtype, n->data_layout);
+    } else {
+      Expr ret = ForwardOp(ref_call, {n->data});
+      return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype, n->data_layout);
+    }
   }
   CHECK(!new_args[0]->derived_from<TempExprNode>());
   return Expr(nullptr);
 }
+
 RELAY_REGISTER_OP("nn.relu")
 .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
 
@@ -530,7 +1119,7 @@ Expr CastDtypeInputRealize(const Call& ref_call,
   if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
     Expr data = Cast(n->data, cfg->dtype_input);
     Expr ret = ForwardOp(ref_call, {data});
-    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
+    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input, n->data_layout);
   }
   CHECK(!new_args[0]->derived_from<TempExprNode>());
   return Expr(nullptr);
@@ -551,7 +1140,7 @@ Expr AvgPoolRealize(const Call& ref_call,
     data = Cast(n->data, cfg->dtype_activation);
   }
   Expr ret = ForwardOp(ref_call, {data});
-  return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation);
+  return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation, n->data_layout);
 }
 CHECK(!new_args[0]->derived_from<TempExprNode>());
 return Expr(nullptr);
@@ -567,7 +1156,7 @@ Expr ForceCastRealize(const Call& ref_call,
   CHECK_EQ(new_args.size(), 1);
   if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
     Expr ret = Cast(n->data, cfg->dtype_input);
-    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
+    return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input, n->data_layout);
   }
   CHECK(!new_args[0]->derived_from<TempExprNode>());
   return Expr(nullptr);
@@ -638,6 +1227,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
   p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
   p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
+  p->stream << "passthrough_bound=" << op->passthrough_bound << ", ";
+  // TODO(eqy): p->stream << "use_stop_fusion==" << op->use_stop_fusion << ", ";
+  p->stream << "granularity=" << op->granularity;
+  p->stream << ")";
   p->stream << ")";
 });
diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h
index 262d420acf97..20ed556115e0 100644
--- a/src/relay/pass/quantize.h
+++ b/src/relay/pass/quantize.h
@@ -108,6 +108,7 @@ class QRealizeExprNode : public TempExprNode {
  public:
   /*! \brief The original expression */
   Expr data;
+  std::string data_layout;
   static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
   TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode);
 };
@@ -125,11 +126,12 @@ class QRealizeIntExprNode : public QRealizeExprNode {
     v->Visit("data", &data);
     v->Visit("dom_scale", &dom_scale);
     v->Visit("dtype", &dtype);
+    v->Visit("data_layout", &data_layout);
   }
 
   Expr Realize() const final;
 
-  TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
+  TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype, std::string data_layout);
 
   static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
   TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode);
@@ -152,10 +154,14 @@ class QConfigNode : public Node {
   DataType dtype_weight = Int(8);
   DataType dtype_activation = Int(32);
   double global_scale = 8.0;
+  // TODO(eqy): int skip_k_conv = 1;
+  int passthrough_bound = 1e9;
   Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
   bool round_for_shift = true;
   bool store_lowbit_output = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
+  // TODO(eqy): bool use_stop_fusion = true;
+  std::string granularity = "layer";
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("nbit_input", &nbit_input);
@@ -165,10 +171,14 @@ class QConfigNode : public Node {
     v->Visit("dtype_weight", &dtype_weight);
     v->Visit("dtype_activation", &dtype_activation);
     v->Visit("global_scale", &global_scale);
+    v->Visit("passthrough_bound", &passthrough_bound);
+    // TODO(eqy): v->Visit("skip_k_conv", &skip_k_conv);
     v->Visit("skip_conv_layers", &skip_conv_layers);
     v->Visit("round_for_shift", &round_for_shift);
     v->Visit("store_lowbit_output", &store_lowbit_output);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);
+    // TODO(eqy): v->Visit("use_stop_fusion", &use_stop_fusion);
+    v->Visit("granularity", &granularity);
   }
 
   static constexpr const char* _type_key = "relay.quantize.QConfig";