39 changes: 39 additions & 0 deletions python/tvm/relay/build_module.py
@@ -28,6 +28,8 @@
from . import ty as _ty
from . import expr as _expr
from .module import Module as _Module
# TODO(eqy)
# from . import transform as _transform
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor

@@ -113,6 +115,43 @@ def build(self, func, target=None, target_host=None, params=None):

return graph_json, mod, params

# TODO(eqy)
# def _setup_build_config(self, params):
# cfg = _transform.PassContext.current()
#
# # Set opt_level.
# self.set_opt_level(cfg.opt_level)
#
# # Set fallback device if it is available.
# if cfg.fallback_device:
# self.set_fallback_device(cfg.fallback_device)
#
# # Add required passes.
# if cfg.required_pass:
# passes = set()
# if isinstance(cfg.required_pass, (list, tuple, set)):
# passes = set(cfg.required_pass)
# else:
# raise TypeError("add_pass must be list, tuple, or set, but " +
# "got {}".format(type(cfg.required_pass)))
# for pass_name in passes:
# self.add_pass(pass_name)
#
# # Add disabled passes.
# if cfg.disabled_pass:
# passes = set()
# if isinstance(cfg.disabled_pass, (list, tuple, set)):
# passes = set(cfg.disabled_pass)
# else:
# raise TypeError("disable_pass must be list, tuple, or set, " +
# "but got {}".format(type(cfg.disabled_pass)))
# for pass_name in passes:
# print("DISABLING", pass_name)
# self.disable_pass(pass_name)
#
# if params:
# self._set_params(params)

def _set_params(self, params):
inputs = {}
for name, param in params.items():
155 changes: 131 additions & 24 deletions python/tvm/relay/quantize/_annotate.py
@@ -22,7 +22,7 @@
import topi
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig
from .quantize import annotate_context
from .quantize import annotate_context, _get_scale_counter, _set_scale_counter, _get_layout
from .. import expr as _expr
from .. import op as _op
from ..op import op as _reg
@@ -36,11 +36,62 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):
assert len(inputs) == 4
assert attrs.sign
assert attrs.rounding == "round"

data, scale, clip_min, clip_max = inputs

if data.op == _op.get("concatenate"):
raise RuntimeError("simulated_quantize cannot be attached directly to concatenate")

if 'O' in attrs.layout and 'I' in attrs.layout:
channel_dim = attrs.layout.index('I')
else:
channel_dim = attrs.layout.index('C')

shape = [topi.util.get_const_int(dim) for dim in data.shape]
channels = 0

if 'broadcastable' in attrs.op_hint and len(data.shape) < len(attrs.layout):
# assert len(data.shape) == len(attrs.layout) - 1, "Unsupported broadcast"
if len(data.shape) == len(attrs.layout) - 1:
for d in range(0, len(data.shape)):
if shape[d] != 1:
channel_dim = d
channels = shape[d]
break
else:
channel_dim = 0
channels = 1

if 'dense' in attrs.op_hint:
channel_dim = 1
channels = 1

if attrs.passthrough:
# if original value should be passed through
assert attrs.kind != QAnnotateKind.WEIGHT
rdata = topi.identity(data)
return [rdata]

# simulate rounding error
scaled_data = topi.divide(data, scale)
if attrs.granularity == 'channel':
assert len(data.shape) >= 4 or 'broadcastable' in attrs.op_hint or 'dense' in attrs.op_hint,\
"Unsupported Layout / Shape Broadcast"
# TODO consider memory/performance tradeoffs of not using a big scale
# tensor (does this impact fusion with other ops?)

# TODO blocked layouts
if channels == 0:
channels = shape[channel_dim]
scale_chunk_shape = [1] * len(shape)
scale_chunk_shape[channel_dim] = channels
scale_chunk = topi.reshape(scale, scale_chunk_shape)
scale_tensor = topi.broadcast_to(scale_chunk, shape)
scaled_data = topi.divide(data, scale_tensor)
scale = scale_tensor
else:
scaled_data = topi.divide(data, scale)
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
round_data = topi.round(clipped_data)

@@ -49,6 +100,7 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):
return [rdata]
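
For intuition, here is a minimal NumPy sketch of what this compute simulates for granularity='channel' on an NCHW tensor; the layout, the shapes, and the final multiply back by the scale are illustrative assumptions based on the code above, not lines from this PR:

    import numpy as np

    def simulated_quantize_channel_nchw(data, scale, clip_min, clip_max):
        # data: (N, C, H, W) float tensor; scale: per-channel, shape (C,)
        # reshape to (1, C, 1, 1) so division broadcasts along the channel
        # axis, mirroring the reshape/broadcast_to sequence above
        scale = scale.reshape(1, -1, 1, 1)
        scaled = data / scale
        clipped = np.clip(scaled, clip_min, clip_max)  # maximum(minimum(...))
        rounded = np.round(clipped)                    # simulated integer rounding
        return rounded * scale                         # back to the real domain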



_reg.register_schedule("relay.op.annotation.simulated_quantize",
_reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize",
@@ -116,7 +168,8 @@ def frewrite_with_guard(ref_call, new_args, ctx):
return _register(frewrite) if frewrite is not None else _register


def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
@register_func("relay.quantize.attach_simulated_quantize")
def attach_simulated_quantize(data, kind, layout=None, op_hint="", len_hint=-1, sign=True, rounding="round"):
"""Attach a simulated quantize operation after input data expr.

Parameters
@@ -127,25 +180,34 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
kind: QAnnotateKind
the kind of annotation field.
layout: str, optional
the layout string of the related op (data or kernel layout), used to
locate the channel axis when per-channel scales are simulated.
op_hint: str, optional
a hint describing the consuming op (e.g. "dense", "broadcastable_mul",
"broadcastable_add", "depthwise_sep_single") that adjusts channel-axis
inference for broadcast weights.
len_hint: int, optional
the channel count hint for broadcast weights; -1 when unused.
"""

quantize_op = _op.get("relay.op.annotation.simulated_quantize")
if isinstance(data, _expr.Call) and data.op == quantize_op:
if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding:
return data

# TODO(eqy): check
actx = annotate_context()
counter = _get_scale_counter()
# the cached qnode bakes layout/op_hint/len_hint into its attrs, so they
# must be part of the memoization key
key = (data, kind, layout, op_hint, len_hint, sign, rounding)
if key in actx.qnode_map:
return actx.qnode_map[key]

dom_scale = _expr.var("dom_scale")
granularity = current_qconfig().granularity
dom_scale = _expr.var("dom_scale" + str(counter))
clip_min = _expr.var("clip_min")
clip_max = _expr.var("clip_max")
passthrough = False
passthrough_bound = current_qconfig().passthrough_bound
if kind != QAnnotateKind.WEIGHT:
passthrough = counter > passthrough_bound
_set_scale_counter(counter + 1)

qnode = _quantize.simulated_quantize(
data, dom_scale, clip_min, clip_max, kind, sign, rounding)
data, dom_scale, clip_min, clip_max, kind, sign, rounding,
passthrough, granularity, layout, op_hint, len_hint)
actx.qnode_map[key] = qnode
return qnode

register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
return qnode
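
As a usage sketch (hypothetical shapes and layout; assumes a default qconfig is active so granularity and passthrough_bound resolve):

    # annotate a conv input with an INPUT-kind simulated quantize; the
    # layout string lets the compute locate the channel axis later
    data = _expr.var("data", shape=(1, 16, 32, 32))
    qdata = attach_simulated_quantize(data, QAnnotateKind.INPUT, layout="NCHW")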


@register_annotate_function("nn.contrib_conv2d_NCHWc")
@@ -168,17 +230,39 @@ def conv2d_rewrite(ref_call, new_args, ctx):
return None
actx.count_conv2d()

op_hint = ""
len_hint = -1

data_layout = ref_call.attrs.data_layout
kernel_layout = ref_call.attrs.kernel_layout

_set_conv_counter(cnt + 1)

lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
assert _get_layout(ref_call) is not None
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT, _get_layout(ref_call))

assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
assert kernel_layout is not None
if 'I' in kernel_layout:
assert 'C' in data_layout
kernel_channels = rhs_expr.data.shape[kernel_layout.index('I')]
if kernel_channels <= 1:
data_channels = ref_call.args[0].checked_type.shape[data_layout.index('C')]
if int(data_channels) > int(kernel_channels):
op_hint = 'depthwise_sep_single'
len_hint = int(data_channels)

if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
assert data_layout is not None
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT, data_layout)

rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT, kernel_layout, op_hint, len_hint)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])

return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
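
To make the depthwise hint concrete: with kernel_layout "OIHW" a depthwise kernel has a unit 'I' dimension while the input still has many channels, so the weight gets the 'depthwise_sep_single' hint and len_hint carries the real channel count. A hedged illustration with made-up shapes:

    data_layout, kernel_layout = "NCHW", "OIHW"
    data_shape = (1, 32, 56, 56)
    kernel_shape = (32, 1, 3, 3)  # depthwise: one input channel per group

    kernel_channels = kernel_shape[kernel_layout.index('I')]  # 1
    data_channels = data_shape[data_layout.index('C')]        # 32
    if kernel_channels <= 1 and data_channels > kernel_channels:
        op_hint, len_hint = 'depthwise_sep_single', data_channels  # 32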


@@ -202,11 +286,16 @@ def dense_rewrite(ref_call, new_args, ctx):
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
op_hint = "dense"
len_hint = 1
assert _get_layout(ref_call) is not None
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT,
_get_layout(ref_call), op_hint, len_hint)

assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT,
_get_layout(ref_call), op_hint, len_hint)

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])

@@ -226,11 +315,14 @@ def multiply_rewrite(ref_call, new_args, ctx):
return None

if lhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and rhs_kind is None:
assert _get_layout(ref_call) is not None
# quantize lhs to INPUT field
if lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT,
_get_layout(ref_call))
# quantize rhs to WEIGHT field
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT,
_get_layout(ref_call), 'broadcastable_mul')

expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
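
The 'broadcastable_mul' hint ties back to the broadcast handling in simulated_quantize_compute: a weight with one fewer dimension than the layout (e.g. a per-channel vector multiplied into an NCHW activation) has its channel axis recovered by scanning for the first non-unit dimension. A small sketch of that scan with assumed shapes:

    layout = "NCHW"
    weight_shape = (32, 1, 1)  # broadcasts against an (N, 32, H, W) activation

    # mirror the loop in simulated_quantize_compute: the first non-unit
    # dimension is taken as the channel axis; an all-ones shape falls
    # back to a single shared scale
    channel_dim, channels = 0, 1
    if len(weight_shape) == len(layout) - 1:
        for d, extent in enumerate(weight_shape):
            if extent != 1:
                channel_dim, channels = d, extent  # channel_dim=0, channels=32
                break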

@@ -249,23 +341,36 @@ def add_rewrite(ref_call, new_args, ctx):
if lhs_kind is None and rhs_kind is None:
return None

assert _get_layout(ref_call) is not None
if lhs_kind is None and rhs_kind is not None:
# quantize lhs to INPUT field if it is normal expression
assert rhs_kind == QAnnotateKind.INPUT
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT, _get_layout(ref_call))
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)

# TODO(eqy): check
# if lhs_kind is None and rhs_kind is not None:
# # quantize lhs to INPUT field if it is normal expression
# lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT,
# _get_layout(ref_call))
if lhs_kind is not None and rhs_kind is None:
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT,
_get_layout(ref_call), 'broadcastable_add')
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT, _get_layout(ref_call))
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

# TODO(eqy): check
# if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
# # quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
# rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT,
# _get_layout(ref_call))
#
if lhs_kind is not None and rhs_kind is not None:
if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
@@ -325,9 +430,10 @@ def pool2d_rewrite(ref_call, new_args, ctx):

if x_kind is None:
return None
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)

if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT,
_get_layout(ref_call))
expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)

@@ -368,7 +474,8 @@ def concatenate_rewrite(ref_call, new_args, ctx):
return None
for i, k in enumerate(kind_list):
if k is None:
expr_list[i] = attach_simulated_quantize(expr_list[i], QAnnotateKind.ACTIVATION)
expr_list[i] = attach_simulated_quantize(expr_list[i], QAnnotateKind.ACTIVATION, _get_layout(ref_call))

expr = _forward_op(ref_call, [_expr.Tuple(expr_list)])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
