diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 228619e0ef35..129b3de31ae4 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -33,8 +33,10 @@ check the attributes of the op and decide if it should be offloaded to DNNL.
 """
 import logging
+from functools import reduce
 
 import tvm.ir
+from tvm.ir import Op
 from tvm import relay
 from tvm.relay import transform
 from tvm.relay.expr import GlobalVar
@@ -44,7 +46,7 @@
 from tvm.relay.analysis import analysis as _analysis
 from tvm.relay import expr as _expr
 
-
+from tvm.relay.expr import Call, TupleGetItem
 from ... import _ffi_api
 from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback
 from .register import register_pattern_table
@@ -166,6 +168,94 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
     return append_eltwise_ops(conv_out, with_eltwise)
 
 
+def make_conv_bias_sum_relu_pattern(conv_type, has_relu=True):
+    """Create patterns with sum op.
+
+    Parameters
+    ----------
+    conv_type : str
+        Should be nn.conv1d / nn.conv2d / nn.conv3d.
+    has_relu : bool
+        Whether to attach relu.
+    Returns
+    -------
+    out : CallPattern
+        Call node sequence.
+    """
+    data1 = wildcard()
+    weight = wildcard()
+    bias = wildcard()
+    data2 = wildcard()
+    out = is_op(conv_type)(data1, weight)
+    out = is_op("add")(out, bias)
+    out = is_op("add")(out, data2)
+    if has_relu:
+        out = is_op("nn.relu")(out)
+    return out
+
+
+def get_op_name(expr):
+    """Get the operator name from an expression."""
+    if isinstance(expr, Op):
+        return expr.name
+    if isinstance(expr, Call):
+        return get_op_name(expr.op)
+    if isinstance(expr, TupleGetItem):
+        return get_op_name(expr.tuple_value)
+    if isinstance(expr, relay.Tuple):
+        return get_op_name(expr.fields[0])
+    return ""
+
+
+def get_args(expr):
+    """Get the arguments from an expression."""
+    if isinstance(expr, Call):
+        return expr.args
+    if isinstance(expr, TupleGetItem):
+        return get_args(expr.tuple_value)
+    if isinstance(expr, relay.Tuple):
+        return [arg for args in map(get_args, expr.fields) for arg in args]
+    return []
+
+
+def get_attrs(expr):
+    """Get the attributes from an expression."""
+    if isinstance(expr, Call):
+        return expr.attrs
+    if isinstance(expr, TupleGetItem):
+        return get_attrs(expr.tuple_value)
+    return {}
+
+
+def make_predicate(checker):
+    """Check whether the conv_bias_add_sum pattern is as expected."""
+
+    def predicate(expr):
+        if get_op_name(expr) == "nn.relu":
+            expr = expr.args[0]
+        for e, op_name in zip([expr, expr.args[0]], ["sum", "bias_add"]):
+            args = get_args(e)
+            attrs = get_attrs(e.args[0])
+            if not checker(attrs, args, op_name):
+                return False
+        return True
+
+    return predicate
+
+
+def add_checker(attrs, args, op_name):
+    """Check if add is supported by DNNL."""
+    if op_name == "sum":
+        if tuple(get_shape(args[0])) != tuple(get_shape(args[1])):
+            return False
+    if op_name == "bias_add":
+        channel = dict(attrs)["channels"]
+        const_shape = get_shape(args[1])
+        if channel != reduce(lambda x, y: x * y, const_shape):
+            return False
+    return True
+
+
 def make_dense_pattern(with_bias=True, with_eltwise=None):
     """Create patterns related to nn.dense.
@@ -305,6 +395,20 @@ def pattern_table():
     dnnl_patterns = list()
     dnnl_patterns.append(make_qnn_conv2d_pattern())
     dnnl_patterns.append(make_qnn_dense_pattern())
+    dnnl_patterns.append(
+        (
+            "dnnl.conv2d_bias_sum_relu",
+            make_conv_bias_sum_relu_pattern("nn.conv2d"),
+            make_predicate(add_checker),
+        )
+    )
+    dnnl_patterns.append(
+        (
+            "dnnl.conv2d_bias_sum",
+            make_conv_bias_sum_relu_pattern("nn.conv2d", False),
+            make_predicate(add_checker),
+        )
+    )
 
     elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
     for with_bias in [True, False]:
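These two table entries are consumed by the MergeComposite pass, and they are appended near the top of the table deliberately: patterns are tried in order, so conv + bias + sum (+ relu) must be claimed before the plain conv2d_bias_relu pattern can swallow the conv + bias subchain. A minimal sketch of how the new entry is expected to fire (shapes and variable names are illustrative, not part of the patch; assumes a TVM build that registers the DNNL pattern table):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.register import get_pattern_table

    data = relay.var("data", shape=(1, 32, 8, 8))
    weight = relay.var("weight", shape=(16, 32, 3, 3))
    bias = relay.var("bias", shape=(16, 1, 1))         # 16 elements == channels, so the bias_add check passes
    sum_in = relay.var("sum_in", shape=(1, 16, 6, 6))  # matches the conv output shape, so the sum check passes

    conv = relay.nn.conv2d(data, weight, channels=16, kernel_size=(3, 3))
    out = relay.nn.relu(relay.add(relay.add(conv, bias), sum_in))

    mod = tvm.IRModule.from_expr(out)
    mod = relay.transform.InferType()(mod)  # the predicate calls get_shape, which needs checked types
    mod = relay.transform.MergeComposite(get_pattern_table("dnnl"))(mod)
    print(mod)  # the chain should now sit in a Composite="dnnl.conv2d_bias_sum_relu" function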
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 93c53dda1652..dde415829d49 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -182,8 +182,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;
 
-    // parsing of name to extract attributes
-    auto op_name = nodes_[nid].GetOpName();
     // Define RegExp.
     std::regex bias_add_pat(".*_bias.*");
     std::regex relu_pat(".*_relu.*");
@@ -192,9 +190,16 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::regex clip_pat(".*_clip.*");
     std::regex gelu_pat(".*_gelu.*");
    std::regex swish_pat(".*_swish.*");
+    std::regex sum_pat(".*_sum.*");
+
+    // parsing of name to extract attributes
+    auto op_name = nodes_[nid].GetOpName();
 
     // Parsing post-ops.
     dnnl::post_ops ops;
+    if (std::regex_match(op_name, sum_pat)) {
+      ops.append_sum(1.f);
+    }
     if (std::regex_match(op_name, relu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
     }
@@ -280,6 +285,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
 
   void Convolution(const size_t& nid) {
     auto node = nodes_[nid];
+    auto op_name = nodes_[nid].GetOpName();
 
     // Setup attributes.
     auto src_tr = GetInput(nid, 0);
@@ -361,6 +367,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     // TODO(@apeskov): Simulation of inplace primitive. just as PoC.
     auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout);
+    if (op_name.find("_sum") != std::string::npos) {
+      sum_in_tr = GetInput(nid, node.GetInputs().size() - 1);
+      sum_in_tr = sum_in_tr.TreatAs(dst_layout);
+    }
 
     Submit(dnnl::convolution_forward(conv_prim_desc),
            {{DNNL_ARG_SRC, src_tr},
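The runtime recovers post-ops purely from the composite function's name, so the "_sum" substring introduced by the new patterns is what triggers append_sum, and its check is placed before the eltwise ones because oneDNN applies post-ops in insertion order (the sum must accumulate into the destination before the activation runs). A hypothetical Python mirror of the regex dispatch, for illustration only (the real logic is the C++ above):

    import re

    def parse_post_ops(op_name):
        post_ops = []
        if re.match(".*_sum.*", op_name):
            post_ops.append("sum")  # appended first: accumulate, then activate
        if re.match(".*_relu.*", op_name):
            post_ops.append("eltwise_relu")
        return post_ops

    print(parse_post_ops("dnnl.conv2d_bias_sum_relu"))  # ['sum', 'eltwise_relu']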
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 1bf8068b2e40..f12ee7479b85 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -788,6 +788,48 @@ def test_conv2d_pattern(run_module, dtype="float32"):
     run_and_verify_func(config, run_module=run_module, dtype=dtype)
 
 
+def test_conv2d_bias_sum_relu(run_module, dtype="float32"):
+    x_shape = (1, 32, 8, 8)
+    k_shape = (16, 32, 3, 3)
+
+    def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"):
+        out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype)
+        beta = relay.const(np.zeros(k_shape[0]).astype(dtype))
+        gamma = relay.const(np.ones(k_shape[0]).astype(dtype))
+        moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype))
+        moving_var = relay.const(np.ones(k_shape[0]).astype(dtype))
+        out, _, _ = relay.nn.batch_norm(
+            out,
+            gamma=gamma,
+            beta=beta,
+            moving_mean=moving_mean,
+            moving_var=moving_var,
+            axis=1,
+            center=True,
+            scale=True,
+            epsilon=1e-5,
+        )
+        sum_data = relay.var("data1", shape=sum_shape, dtype=dtype)
+        out = relay.add(out, sum_data)
+        dic["data1"] = sum_shape
+        param_lst += ["data1"]
+        return relay.nn.relu(out), dic, param_lst
+
+    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
+        x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype
+    )
+    conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
+    config = conv2d_bn_sum_relu, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+    conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(
+        x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype
+    )
+    conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu)
+    config = conv2d_bn_sum_relu, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
+
 def test_conv2d_transpose(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)
     for k_shape, groups in [((32, 16, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 4, 3, 3), 16)]:
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 4b7ac92136e9..0c702d074eaa 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -919,7 +919,11 @@ def expected():
 
 def test_dnnl_fuse():
     dnnl_patterns = get_pattern_table("dnnl")
-    dnnl_pat_dic = dict(dnnl_patterns)
+    valid_pats = list()
+    for pattern in dnnl_patterns:
+        if len(pattern) == 2:
+            valid_pats.append(pattern)
+    dnnl_pat_dic = dict(valid_pats)
     (
         conv2d_bias_relu_pat,
         conv2d_bias_sigmoid_pat,
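A note on the last hunk: the two new table entries are (name, pattern, predicate) 3-tuples while the pre-existing ones are (name, pattern) pairs, and dict() rejects any update sequence whose elements are not length 2, hence the filtering. A quick generic-Python illustration (names are made up, not from the patch):

    pats = [("dnnl.qnn.conv2d", "pat_a"), ("dnnl.conv2d_bias_sum_relu", "pat_b", "pred")]
    try:
        dict(pats)
    except ValueError as err:
        print(err)  # dictionary update sequence element #1 has length 3; 2 is required
    dict(p[:2] for p in pats)  # alternative: keep all names, just drop the predicates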