Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion python/tvm/relay/op/contrib/dnnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
check the attributes of the op and decide if it should be offloaded to DNNL.
"""
import logging
from functools import reduce

import tvm.ir
from tvm.ir import Op
from tvm import relay
from tvm.relay import transform
from tvm.relay.expr import GlobalVar
Expand All @@ -44,7 +46,7 @@
from tvm.relay.analysis import analysis as _analysis
from tvm.relay import expr as _expr


from tvm.relay.expr import Call, TupleGetItem
from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
from .register import register_pattern_table
Expand Down Expand Up @@ -166,6 +168,94 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
return append_eltwise_ops(conv_out, with_eltwise)


def make_conv_bias_sum_relu_pattern(conv_type, has_relu=True):
"""Create patterns with sum op.

Parameters
----------
conv_type : str
Should be nn.conv1d / nn.conv2d / nn.conv3d.
has_relu : bool
Whether attach relu.
Returns
-------
out : CallPattern
Call node sequence.
"""
data1 = wildcard()
weight = wildcard()
bias = wildcard()
data2 = wildcard()
out = is_op(conv_type)(data1, weight)
out = is_op("add")(out, bias)
out = is_op("add")(out, data2)
if has_relu:
out = is_op("nn.relu")(out)
return out


def get_op_name(expr):
"""Get the operator name from an expression."""
if isinstance(expr, Op):
return expr.name
if isinstance(expr, Call):
return get_op_name(expr.op)
if isinstance(expr, TupleGetItem):
return get_op_name(expr.tuple_value)
if isinstance(expr, relay.Tuple):
return get_op_name(expr.fields[0])
return ""


def get_args(expr):
"""Get the arguments from an expression."""
if isinstance(expr, Call):
return expr.args
if isinstance(expr, TupleGetItem):
return get_args(expr.tuple_value)
if isinstance(expr, relay.Tuple):
return [arg for args in map(get_args, expr.fields) for arg in args]
return []


def get_attrs(expr):
"""Get the attributes from an expression."""
if isinstance(expr, Call):
return expr.attrs
if isinstance(expr, TupleGetItem):
return get_attrs(expr.tuple_value)
return {}


def make_predicate(checker):
"""Check whether the conv_bias_add_sum pattern is as expected."""

def predicate(expr):
if get_op_name(expr) == "nn.relu":
expr = expr.args[0]
for e, op_name in zip([expr, expr.args[0]], ["sum", "bias_add"]):
args = get_args(e)
attrs = get_attrs(e.args[0])
if not checker(attrs, args, op_name):
return False
return True

return predicate


def add_checker(attrs, args, op_name):
"""Check if add is supported by DNNL."""
if op_name == "sum":
if tuple(get_shape(args[0])) != tuple(get_shape(args[1])):
return False
if op_name == "bias_add":
channel = dict(attrs)["channels"]
const_shape = get_shape(args[1])
if channel != reduce(lambda x, y: x * y, const_shape):
return False
return True


def make_dense_pattern(with_bias=True, with_eltwise=None):
"""Create patterns related to nn.dense.

Expand Down Expand Up @@ -305,6 +395,20 @@ def pattern_table():
dnnl_patterns = list()
dnnl_patterns.append(make_qnn_conv2d_pattern())
dnnl_patterns.append(make_qnn_dense_pattern())
dnnl_patterns.append(
(
"dnnl.conv2d_bias_sum_relu",
make_conv_bias_sum_relu_pattern("nn.conv2d"),
make_predicate(add_checker),
)
)
dnnl_patterns.append(
(
"dnnl.conv2d_bias_sum",
make_conv_bias_sum_relu_pattern("nn.conv2d", False),
make_predicate(add_checker),
)
)

elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
for with_bias in [True, False]:
Expand Down
14 changes: 12 additions & 2 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;

// parsing of name to extract attributes
auto op_name = nodes_[nid].GetOpName();
// Define RegExp.
std::regex bias_add_pat(".*_bias.*");
std::regex relu_pat(".*_relu.*");
Expand All @@ -192,9 +190,16 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::regex clip_pat(".*_clip.*");
std::regex gelu_pat(".*_gelu.*");
std::regex swish_pat(".*_swish.*");
std::regex sum_pat(".*_sum.*");

// parsing of name to extract attributes
auto op_name = nodes_[nid].GetOpName();

// Parsing post-ops.
dnnl::post_ops ops;
if (std::regex_match(op_name, sum_pat)) {
ops.append_sum(1.f);
}
if (std::regex_match(op_name, relu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
}
Expand Down Expand Up @@ -280,6 +285,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

void Convolution(const size_t& nid) {
auto node = nodes_[nid];
auto op_name = nodes_[nid].GetOpName();

// Setup attributes.
auto src_tr = GetInput(nid, 0);
Expand Down Expand Up @@ -361,6 +367,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase {

// TODO(@apeskov): Simulation of inplace primitive. just as PoC.
auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout);
if (op_name.find("_sum") != std::string::npos) {
sum_in_tr = GetInput(nid, node.GetInputs().size() - 1);
sum_in_tr = sum_in_tr.TreatAs(dst_layout);
}

Submit(dnnl::convolution_forward(conv_prim_desc),
{{DNNL_ARG_SRC, src_tr},
Expand Down
42 changes: 42 additions & 0 deletions tests/python/contrib/test_dnnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,48 @@ def test_conv2d_pattern(run_module, dtype="float32"):
run_and_verify_func(config, run_module=run_module, dtype=dtype)


def test_conv2d_bias_sum_relu(run_module, dtype="float32"):
x_shape = (1, 32, 8, 8)
k_shape = (16, 32, 3, 3)

def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype))
moving_var = relay.const(np.ones(k_shape[0]).astype(dtype))
out, _, _ = relay.nn.batch_norm(
out,
gamma=gamma,
beta=beta,
moving_mean=moving_mean,
moving_var=moving_var,
axis=1,
center=True,
scale=True,
epsilon=1e-5,
)
sum_data = relay.var("data1", shape=sum_shape, dtype=dtype)
out = relay.add(out, sum_data)
dic["data1"] = sum_shape
param_lst += ["data1"]
return relay.nn.relu(out), dic, param_lst

conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype
)
conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
config = conv2d_bn_sum_relu, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)

conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype
)
conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
config = conv2d_bn_sum_relu, dic, param_lst
run_and_verify_func(config, run_module=run_module, dtype=dtype)


def test_conv2d_transpose(run_module, dtype="float32"):
x_shape = (1, 32, 8, 8)
for k_shape, groups in [((32, 16, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 4, 3, 3), 16)]:
Expand Down
6 changes: 5 additions & 1 deletion tests/python/relay/test_pass_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,11 @@ def expected():

def test_dnnl_fuse():
dnnl_patterns = get_pattern_table("dnnl")
dnnl_pat_dic = dict(dnnl_patterns)
valid_pats = list()
for pattern in dnnl_patterns:
if len(pattern) == 2:
valid_pats.append(pattern)
dnnl_pat_dic = dict(valid_pats)
(
conv2d_bias_relu_pat,
conv2d_bias_sigmoid_pat,
Expand Down