279 changes: 275 additions & 4 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -39,13 +39,14 @@
from tvm.relay import transform
from tvm.relay.expr import GlobalVar
from tvm.relay.expr_functor import ExprMutator, ExprVisitor
from tvm.relay.expr import const

from tvm.relay.analysis import analysis as _analysis
from tvm.relay import expr as _expr


from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op, is_expr, rewrite, DFPatternCallback
from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table


@@ -56,8 +57,8 @@ def _register_external_op_helper(op_name, supported=True):
"""The helper function to indicate that a given operator can be supported
by DNNL.

Paramters
---------
Parameters
----------
op_name : Str
The name of operator that will be registered.

@@ -69,6 +70,10 @@ def _register_external_op_helper(op_name, supported=True):

@tvm.ir.register_op_attr(op_name, "target.dnnl")
def _func_wrapper(expr):
args = expr.args
if any([x.checked_type.dtype == "int64" for x in args]):
logger.info("DNNL does not support int64.")
return False
return supported

return _func_wrapper
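
For illustration only, a rough sketch (the names below are invented for this example and are not part of the patch) of how the registered "target.dnnl" predicate behaves on an int64 graph, assuming the registrations in this module have been imported:

import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl  # noqa: F401 -- importing runs the registrations above

# A tiny typed module whose nn.relu operates on int64 tensors.
x = relay.var("x", shape=(8,), dtype="int64")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.relu(x)))
mod = relay.transform.InferType()(mod)

# Look up the predicate registered for nn.relu and call it on the typed call node.
supported = relay.op.get("nn.relu").get_attr("target.dnnl")
print(bool(supported(mod["main"].body)))  # expected: False, int64 inputs are rejected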
@@ -90,6 +95,7 @@ def _func_wrapper(expr):
_register_external_op_helper("exp")
_register_external_op_helper("log")
_register_external_op_helper("sqrt")
_register_external_op_helper("round")
_register_external_op_helper("nn.relu")
_register_external_op_helper("nn.leaky_relu")
_register_external_op_helper("tanh")
@@ -199,6 +205,70 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
return dnnl_pattern


def make_qnn_conv2d_pattern():
"""Make qnn.conv2d based pattern supported by DNNL

Returns
-------
pattern : Tuple(pattern_name, CallPattern)
Created pattern name, along with its CallPattern.
"""
data = wildcard()
weight = is_constant()
bias = is_constant()
o_scl = is_constant()
dst_zp = is_constant()
act_scl = is_constant()
sum_scl = is_constant()
sum_src = wildcard()

zero_zp = is_expr(const(0, dtype="int32"))

pat = is_op("qnn.conv2d")(data, weight, zero_zp, zero_zp, is_constant(), is_constant())
pat = is_op("cast")(pat)
pat = is_op("add")(pat, bias) | pat # optional bias
pat = is_op("multiply")(pat, o_scl)
pat = is_op("clip")(pat) # TBD, not only clip
pat = is_op("multiply")(pat, act_scl) | pat # optional multiply. Ex: act_scl == 1
pat = is_op("add")(pat, sum_scl * is_op("cast")(sum_src)) | pat # optional sum
pat = is_op("add")(pat, dst_zp) | pat # optional dst_zp, can be dst_zp == 0
pat = is_op("cast")(pat)

return "dnnl.qnn.conv2d", pat


def make_qnn_dense_pattern():
"""Make qnn.dense based pattern supported by DNNL

Returns
-------
pattern : Tuple(pattern_name, CallPattern)
Created pattern name, along with its CallPattern.
"""
data = wildcard()
weight = is_constant()
bias = is_constant()
o_scl = is_constant()
dst_zp = is_constant()
act_scl = is_constant()
sum_scl = is_constant()
sum_src = wildcard()

zero_zp = is_expr(const(0, dtype="int32"))

pat = is_op("qnn.dense")(data, weight, zero_zp, zero_zp, is_constant(), is_constant())
pat = is_op("cast")(pat)
pat = is_op("add")(pat, bias) | pat # optional bias
pat = is_op("multiply")(pat, o_scl)
pat = is_op("clip")(pat) # TBD, not only clip
pat = is_op("multiply")(pat, act_scl) | pat # optional multiply. ex act_scl == 1
pat = is_op("add")(pat, sum_scl * is_op("cast")(sum_src)) | pat # optional sum
pat = is_op("add")(pat, dst_zp) | pat # optional dst_zp, can be dst_zp == 0
pat = is_op("cast")(pat)

return "dnnl.qnn.dense", pat


@register_pattern_table("dnnl")
def pattern_table():
"""Create dnnl patterns.
@@ -208,8 +278,11 @@ def pattern_table():
dnnl_patterns : List[dnnl_pattern]
Created patterns.
"""
dnnl_patterns = list()
dnnl_patterns.append(make_qnn_conv2d_pattern())
dnnl_patterns.append(make_qnn_dense_pattern())

elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
dnnl_patterns = []
for with_bias in [True, False]:
for elt in elt_list:
if not with_bias and not elt:
@@ -707,3 +780,201 @@ def rewrite_dense_bias_gelu_reshape_last(mod):
[DenseReshapeBiasGeluRewrite(), DenseReshapeBiasGeluRewrite(has_gelu=False)], mod["main"]
)
return mod


class LegalizeQnnOpForDnnl(DFPatternCallback):
"""Legalize QNN based patterns to match DNNL

original pattern:
OP = qnn.dense | qnn.conv2d
%1 = OP<int>(SRC, WGH) - OP<int>(src_zp, WGH) // qnn.conv2d
%2 = %1 + orig_bias // bias
%2 = (%1 - rq_in_zp) * rq_in_scl / rq_out_scl + rq_out_zp // qnn.requantize
%3 = act(%2) // activation == clip
%4 = ((%3 - sum_lhs_zp) * sum_lhs_scl + (SRC2 - sum_rhs_zp) * sum_rhs_scl) // qnn.add
/ sum_out_scl + sum_out_zp

transform to DNNL compatible:
%1 = OP<int>(SRC, WGH)
%2 = cast(%1, dtype="float")
%2 = (%1 + bias) * o_scl
%3 = act(%2) * act_scl
%4 = %3 + SRC2 * sum_scl
%5 = %4 + dst_zp
%6 = cast(%5, dtype=<dst dtype>)

where:
o_scl = rq_in_scl / rq_out_scl
act_scl = sum_lhs_scl / sum_out_scl
sum_scl = sum_rhs_scl / sum_out_scl
bias = orig_bias - OP(src_zp, WGH) - rq_in_zp + rq_out_zp * rq_out_scl / rq_in_scl
dst_zp = sum_out_zp - sum_lhs_zp * sum_lhs_scl / sum_out_scl -
sum_rhs_zp * sum_rhs_scl / sum_out_scl
"""

def __init__(self):
super(LegalizeQnnOpForDnnl, self).__init__()
self.src = wildcard()
self.wgh = wildcard()
self.bias = wildcard()
self.sum_src = wildcard()

self.src_scl = is_constant()
self.src_zp = is_constant()
self.wgh_scl = is_constant()
self.wgh_zp = is_expr(const(0))

self.rq_in_scl = is_constant()
self.rq_in_zp = is_constant()
self.rq_out_scl = is_constant()
self.rq_out_zp = is_constant()

self.sum_lhs_scl = is_constant()
self.sum_lhs_zp = is_constant()
self.sum_rhs_scl = is_constant()
self.sum_rhs_zp = is_constant()
self.sum_out_scl = is_constant()
self.sum_out_zp = is_constant()

self.root = (is_op("qnn.conv2d") | is_op("qnn.dense"))(
self.src, self.wgh, self.src_zp, self.wgh_zp, self.src_scl, self.wgh_scl
)
pat = is_op("add")(self.root, self.bias) | self.root # optional bias
pat = is_op("qnn.requantize")(
pat, self.rq_in_scl, self.rq_in_zp, self.rq_out_scl, self.rq_out_zp
)
pat = is_op("clip")(pat)
cast = is_op("cast")(pat)
pat = is_op("qnn.add")(
cast,
self.sum_src,
self.sum_lhs_scl,
self.sum_lhs_zp,
self.sum_rhs_scl,
self.sum_rhs_zp,
self.sum_out_scl,
self.sum_out_zp,
)
pat = is_op("clip")(pat)
self.pattern = pat | cast

def callback(self, pre, post, node_map):
root = node_map[self.root][0]
src = node_map[self.src][0]
wgh = node_map[self.wgh][0]
bias = node_map.get(self.bias, default=[relay.const(0, dtype="int32")])[0]
src_zp = node_map[self.src_zp][0]
rq_in_scl = node_map[self.rq_in_scl][0]
rq_in_zp = node_map[self.rq_in_zp][0]
rq_out_scl = node_map[self.rq_out_scl][0]
rq_out_zp = node_map[self.rq_out_zp][0]

final_dtype = node_map[self.pattern][0].checked_type.dtype

if root.op == relay.op.get("qnn.conv2d"):
dst_layout = root.attrs.out_layout
dst_layout = root.attrs.data_layout if dst_layout == "" else dst_layout
wgh_layout = root.attrs.kernel_layout
else:
# qnn.dense has no layout attributes. Assume that it is plain
dst_layout = "NC"
wgh_layout = "OI"

# TODO(@apeskov): dst_layout may be blocked
bias_rank = len(dst_layout) - dst_layout.index("C")

sum_src = node_map[self.sum_src][0] if self.sum_src in node_map else None
# Default values if qnn.add is not present
sum_lhs_scl = node_map[self.sum_lhs_scl][0] if sum_src else relay.const(1, dtype="float32")
sum_lhs_zp = node_map[self.sum_lhs_zp][0] if sum_src else relay.const(0, dtype="int32")
sum_rhs_scl = node_map[self.sum_rhs_scl][0] if sum_src else relay.const(0, dtype="float32")
sum_rhs_zp = node_map[self.sum_rhs_zp][0] if sum_src else relay.const(0, dtype="int32")
sum_out_scl = node_map[self.sum_out_scl][0] if sum_src else relay.const(1, dtype="float32")
sum_out_zp = node_map[self.sum_out_zp][0] if sum_src else relay.const(0, dtype="int32")

def cast_fp(op):
return relay.op.cast(op, dtype="float32")

# recalculate some factors
o_scl = rq_in_scl / rq_out_scl
act_scl = sum_lhs_scl / sum_out_scl
sum_scl = sum_rhs_scl / sum_out_scl
dst_zp = (
cast_fp(sum_out_zp)
- cast_fp(sum_lhs_zp) * sum_lhs_scl / sum_out_scl
- cast_fp(sum_rhs_zp) * sum_rhs_scl / sum_out_scl
)
bias = self.squeeze_bias(bias, dst_layout)
bias = (
cast_fp(bias)
- cast_fp(self.fake_op(src_zp, wgh, wgh_layout))
- cast_fp(rq_in_zp)
+ cast_fp(rq_out_zp) * rq_out_scl / rq_in_scl
)
bias = self.broadcast_to_rank(bias, bias_rank)

zero_zp = relay.const(0, dtype="int32")
one_scl = relay.const(1.0, dtype="float32")

# construct new graph with proper post op ordering
gr = tvm.relay.Call(
root.op,
[src, wgh, zero_zp, zero_zp, one_scl, one_scl],
root.attrs,
root.type_args,
root.span,
)
gr = relay.op.cast(gr, dtype="float32")
gr = gr + bias
gr = gr * o_scl
gr = relay.op.clip(gr, 0, 255) * act_scl
gr = gr + sum_scl * cast_fp(sum_src) if sum_src else gr
gr = gr + dst_zp
gr = relay.op.cast(gr, dtype=final_dtype)
return gr

@staticmethod
def fake_op(zp, wgh, layout):
"""Fake operator implementation for zp broadcast input"""
# Conv: reduce kernel {OC, IC, KH, KW} -> {OC}; still correct for grouped convolution
# Dense: reduce kernel {OC, IC} -> {OC}
wgh_int = relay.op.cast(wgh, dtype="int32")
reduced_kernel = relay.op.sum(
wgh_int, axis=[layout.index("O")], keepdims=False, exclude=True
)
return zp * reduced_kernel

@staticmethod
def squeeze_bias(bias, layout):
shape = transform.InferTypeLocal(bias).concrete_shape
c_position = layout.index("C") - len(layout) + len(shape)
squeeze_idxs = [i for i in range(len(shape)) if i != c_position]
return relay.op.squeeze(bias, squeeze_idxs)

@staticmethod
def broadcast_to_rank(op, rank):
"""Scalar or 1D tensor are supported"""
shape = transform.InferTypeLocal(op).concrete_shape
if len(shape) == 0:
return op
if len(shape) == 1:
return relay.op.expand_dims(op, 1, rank - 1)
raise ValueError("Unexpected bias rank to broadcast. Only 0 and 1 are supported.")


def legalize_qnn_for_dnnl(mod):
"""Transform qnn primitives to DNNL compatible form. Eliminate source zero point and apply
strict sequence of post ops."""
mod["main"] = rewrite(LegalizeQnnOpForDnnl(), mod["main"])

seq = tvm.transform.Sequential(
[
transform.InferType(),
# transform.SimplifyInference(), # TODO: this pass decomposes nn.layer_norm
# transform.FoldScaleAxis(), # TODO: fails inside TVM in case of grouped convolutions.
transform.FoldConstant(),
]
)
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
return mod
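
A hedged end-to-end sketch of how this helper is typically combined with the pattern table above in a standard BYOC flow. The wrapper name below is invented for this illustration; the passes are standard Relay transforms, and the exact sequence used in practice may differ:

import tvm
from tvm import relay
from tvm.relay.op.contrib import dnnl

def partition_for_dnnl_sketch(mod):
    """Illustrative only: legalize QNN ops, then merge, annotate and partition DNNL regions."""
    mod = dnnl.legalize_qnn_for_dnnl(mod)
    seq = tvm.transform.Sequential(
        [
            relay.transform.InferType(),
            relay.transform.MergeComposite(dnnl.pattern_table()),
            relay.transform.AnnotateTarget("dnnl"),
            relay.transform.MergeCompilerRegions(),
            relay.transform.PartitionGraph(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        return seq(mod)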