From 59918fae5996f850120e1652fb0ba05c92c93290 Mon Sep 17 00:00:00 2001 From: Ivy Date: Thu, 21 Jul 2022 21:06:25 +0800 Subject: [PATCH 1/5] add post_sum pattern --- python/tvm/relay/op/contrib/dnnl.py | 17 ++++++++++ src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 14 +++++++-- tests/python/contrib/test_dnnl.py | 31 +++++++++++++++++++ 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 228619e0ef35..def227da01e7 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -166,6 +166,19 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None): return append_eltwise_ops(conv_out, with_eltwise) +def make_conv_add_sum_relu_pattern(conv_type, has_relu=True): + data1 = wildcard() + weight = wildcard() + bias = wildcard() + data2 = wildcard() + out = is_op(conv_type)(data1, weight) + out = is_op("add")(out, bias) + out = is_op("add")(out, data2) + if has_relu: + out = is_op("nn.relu")(out) + return out + + def make_dense_pattern(with_bias=True, with_eltwise=None): """Create patterns related to nn.dense. @@ -305,6 +318,10 @@ def pattern_table(): dnnl_patterns = list() dnnl_patterns.append(make_qnn_conv2d_pattern()) dnnl_patterns.append(make_qnn_dense_pattern()) + dnnl_patterns.append(("dnnl.conv2d_bias_sum_relu", + make_conv_add_sum_relu_pattern("nn.conv2d"))), + dnnl_patterns.append(("dnnl.conv2d_bias_sum", + make_conv_add_sum_relu_pattern("nn.conv2d", False))), elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None] for with_bias in [True, False]: diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 93c53dda1652..6627fc82a5a8 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -182,8 +182,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr; - // parsing of name to extract attributes - auto op_name = nodes_[nid].GetOpName(); // Define RegExp. std::regex bias_add_pat(".*_bias.*"); std::regex relu_pat(".*_relu.*"); @@ -192,9 +190,16 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::regex clip_pat(".*_clip.*"); std::regex gelu_pat(".*_gelu.*"); std::regex swish_pat(".*_swish.*"); + std::regex sum_pat(".*_sum.*"); + + // parsing of name to extract attributes + auto op_name = nodes_[nid].GetOpName(); // Parsing post-ops. dnnl::post_ops ops; + if (std::regex_match(op_name, sum_pat)) { + ops.append_sum(1.f); + } if (std::regex_match(op_name, relu_pat)) { ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f); } @@ -280,6 +285,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { void Convolution(const size_t& nid) { auto node = nodes_[nid]; + auto op_name = nodes_[nid].GetOpName(); // Setup attributes. auto src_tr = GetInput(nid, 0); @@ -361,6 +367,10 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // TODO(@apeskov): Simulation of inplace primitive. just as PoC. auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout); + if (op_name.find("_sum") != std::string::npos) { + sum_in_tr = GetInput(nid, node.GetInputs().size()-1); + sum_in_tr = sum_in_tr.TreatAs(dst_layout); + } Submit(dnnl::convolution_forward(conv_prim_desc), {{DNNL_ARG_SRC, src_tr}, diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 1bf8068b2e40..614e3ecfb3f9 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -788,6 +788,37 @@ def test_conv2d_pattern(run_module, dtype="float32"): run_and_verify_func(config, run_module=run_module, dtype=dtype) +def test_conv2d_bias_sum_relu(run_module, dtype="float32"): + x_shape=(1, 32, 8, 8) + k_shape=(16, 32, 3, 3) + def get_conv2d_bn_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"): + out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + beta = relay.const(np.zeros(k_shape[0]).astype(dtype)) + gamma = relay.const(np.ones(k_shape[0]).astype(dtype)) + moving_mean = relay.const(np.zeros(k_shape[0]).astype(dtype)) + moving_var = relay.const(np.ones(k_shape[0]).astype(dtype)) + out, _, _ = relay.nn.batch_norm( + out, + gamma=gamma, + beta=beta, + moving_mean=moving_mean, + moving_var=moving_var, + axis=1, + center=True, + scale=True, + epsilon=1e-5, + ) + sum_data = relay.var("data1", shape=(1, 16, 6, 6), dtype=dtype) + out = relay.add(out, sum_data) + dic["data1"] = (1, 16, 6, 6) + param_lst += ["data1"] + return relay.nn.relu(out), dic, param_lst + conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, dtype=dtype) + conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu) + config = conv2d_bn_sum_relu, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + def test_conv2d_transpose(run_module, dtype="float32"): x_shape = (1, 32, 8, 8) for k_shape, groups in [((32, 16, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 4, 3, 3), 16)]: From 54a693dfaa5d6690dcfc27fad84bf5393aba97da Mon Sep 17 00:00:00 2001 From: Ivy Date: Fri, 22 Jul 2022 10:11:12 +0800 Subject: [PATCH 2/5] add checkers for sum pattern --- python/tvm/relay/op/contrib/dnnl.py | 86 +++++++++++++++++++++++++++-- tests/python/contrib/test_dnnl.py | 5 +- 2 files changed, 84 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index def227da01e7..112becbc06e0 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -33,8 +33,10 @@ check the attributes of the op and decide if it should be offloaded to DNNL. """ import logging +from typing import Tuple, List, Dict, Union, Optional, Any, Callable import tvm.ir +from tvm.ir import Op from tvm import relay from tvm.relay import transform from tvm.relay.expr import GlobalVar @@ -46,6 +48,7 @@ from ... import _ffi_api +from tvm.relay.expr import Call, Constant, TupleGetItem from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback from .register import register_pattern_table @@ -166,7 +169,20 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None): return append_eltwise_ops(conv_out, with_eltwise) -def make_conv_add_sum_relu_pattern(conv_type, has_relu=True): +def make_conv_bias_sum_relu_pattern(conv_type, has_relu=True): + """Create patterns with sum op. + + Parameters + ---------- + conv_type : str + Should be nn.conv1d / nn.conv2d / nn.conv3d. + has_relu : bool + Whether attach relu. + Returns + ------- + out : CallPattern + Call node sequence. + """ data1 = wildcard() weight = wildcard() bias = wildcard() @@ -179,6 +195,64 @@ def make_conv_add_sum_relu_pattern(conv_type, has_relu=True): return out +def get_op_name(expr: relay.expr.Expr) -> str: + """Get the operator name from an expression.""" + if isinstance(expr, Op): + return expr.name + if isinstance(expr, Call): + return get_op_name(expr.op) + if isinstance(expr, TupleGetItem): + return get_op_name(expr.tuple_value) + if isinstance(expr, relay.Tuple): + return get_op_name(expr.fields[0]) + return "" + + +def get_args(expr: relay.expr.Expr) -> List[relay.expr.Expr]: + """Get the arguments from an expression.""" + if isinstance(expr, Call): + return expr.args + if isinstance(expr, TupleGetItem): + return get_args(expr.tuple_value) + if isinstance(expr, relay.Tuple): + return [arg for args in map(get_args, expr.fields) for arg in args] + return [] + + +def get_attrs(expr: relay.expr.Expr) -> Any: + """Get the attributes from an expression.""" + if isinstance(expr, Call): + return expr.attrs + if isinstance(expr, TupleGetItem): + return get_attrs(expr.tuple_value) + return {} + + +def checker() -> Callable[[relay.expr.Expr], bool]: + """Check whether the conv_bias_add_sum pattern is as expected.""" + + def check_sum_pattern(expr: relay.expr.Expr) -> bool: + op_name = get_op_name(expr) + if op_name == "nn.relu": + expr = expr.args[0] + # elementwise add + args = get_args(expr) + if get_shape(args[0]) != get_shape(args[1]): + return False + # bias_add + expr = expr.args[0] + args = get_args(expr) + conv_attrs = get_attrs(expr.args[0]) + channel = dict(conv_attrs)["channels"] + const_shape = get_shape(args[1]) + from functools import reduce + if channel != reduce(lambda x, y: x * y, const_shape): + return False + return True + + return check_sum_pattern + + def make_dense_pattern(with_bias=True, with_eltwise=None): """Create patterns related to nn.dense. @@ -318,10 +392,12 @@ def pattern_table(): dnnl_patterns = list() dnnl_patterns.append(make_qnn_conv2d_pattern()) dnnl_patterns.append(make_qnn_dense_pattern()) - dnnl_patterns.append(("dnnl.conv2d_bias_sum_relu", - make_conv_add_sum_relu_pattern("nn.conv2d"))), - dnnl_patterns.append(("dnnl.conv2d_bias_sum", - make_conv_add_sum_relu_pattern("nn.conv2d", False))), + dnnl_patterns.append( + ("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), checker()) + ), + dnnl_patterns.append( + ("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), checker()) + ), elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None] for with_bias in [True, False]: diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 614e3ecfb3f9..a6be6f2e26ce 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -789,8 +789,8 @@ def test_conv2d_pattern(run_module, dtype="float32"): def test_conv2d_bias_sum_relu(run_module, dtype="float32"): - x_shape=(1, 32, 8, 8) - k_shape=(16, 32, 3, 3) + x_shape = (1, 32, 8, 8) + k_shape = (16, 32, 3, 3) def get_conv2d_bn_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"): out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype) beta = relay.const(np.zeros(k_shape[0]).astype(dtype)) @@ -813,6 +813,7 @@ def get_conv2d_bn_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype= dic["data1"] = (1, 16, 6, 6) param_lst += ["data1"] return relay.nn.relu(out), dic, param_lst + conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, dtype=dtype) conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu) config = conv2d_bn_sum_relu, dic, param_lst From 6b175d9ce826bb1782ca203c03b14333bba0e3d6 Mon Sep 17 00:00:00 2001 From: Ivy Date: Fri, 22 Jul 2022 10:57:25 +0800 Subject: [PATCH 3/5] fix lint --- python/tvm/relay/op/contrib/dnnl.py | 51 +++++++++++++++-------------- tests/python/contrib/test_dnnl.py | 18 +++++++--- 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 112becbc06e0..fbab9ab745f9 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -33,7 +33,7 @@ check the attributes of the op and decide if it should be offloaded to DNNL. """ import logging -from typing import Tuple, List, Dict, Union, Optional, Any, Callable +from functools import reduce import tvm.ir from tvm.ir import Op @@ -46,9 +46,8 @@ from tvm.relay.analysis import analysis as _analysis from tvm.relay import expr as _expr - +from tvm.relay.expr import Call, TupleGetItem from ... import _ffi_api -from tvm.relay.expr import Call, Constant, TupleGetItem from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback from .register import register_pattern_table @@ -195,7 +194,7 @@ def make_conv_bias_sum_relu_pattern(conv_type, has_relu=True): return out -def get_op_name(expr: relay.expr.Expr) -> str: +def get_op_name(expr): """Get the operator name from an expression.""" if isinstance(expr, Op): return expr.name @@ -208,7 +207,7 @@ def get_op_name(expr: relay.expr.Expr) -> str: return "" -def get_args(expr: relay.expr.Expr) -> List[relay.expr.Expr]: +def get_args(expr): """Get the arguments from an expression.""" if isinstance(expr, Call): return expr.args @@ -219,7 +218,7 @@ def get_args(expr: relay.expr.Expr) -> List[relay.expr.Expr]: return [] -def get_attrs(expr: relay.expr.Expr) -> Any: +def get_attrs(expr): """Get the attributes from an expression.""" if isinstance(expr, Call): return expr.attrs @@ -228,29 +227,33 @@ def get_attrs(expr: relay.expr.Expr) -> Any: return {} -def checker() -> Callable[[relay.expr.Expr], bool]: +def make_predicate(checker): """Check whether the conv_bias_add_sum pattern is as expected.""" - def check_sum_pattern(expr: relay.expr.Expr) -> bool: - op_name = get_op_name(expr) - if op_name == "nn.relu": + def predicate(expr): + if get_op_name(expr) == "nn.relu": expr = expr.args[0] - # elementwise add - args = get_args(expr) - if get_shape(args[0]) != get_shape(args[1]): + for e, op_name in zip([expr, expr.args[0]], ["sum", "bias_add"]): + args = get_args(e) + attrs = get_attrs(e.args[0]) + if not checker(attrs, args, op_name): + return False + return True + + return predicate + + +def add_checker(attrs, args, op_name): + """Check if add is supported by DNNL.""" + if op_name == "sum": + if tuple(get_shape(args[0])) != tuple(get_shape(args[1])): return False - # bias_add - expr = expr.args[0] - args = get_args(expr) - conv_attrs = get_attrs(expr.args[0]) - channel = dict(conv_attrs)["channels"] + if op_name == "bias_add": + channel = dict(attrs)["channels"] const_shape = get_shape(args[1]) - from functools import reduce if channel != reduce(lambda x, y: x * y, const_shape): return False - return True - - return check_sum_pattern + return True def make_dense_pattern(with_bias=True, with_eltwise=None): @@ -393,10 +396,10 @@ def pattern_table(): dnnl_patterns.append(make_qnn_conv2d_pattern()) dnnl_patterns.append(make_qnn_dense_pattern()) dnnl_patterns.append( - ("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), checker()) + ("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), make_predicate(add_checker)) ), dnnl_patterns.append( - ("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), checker()) + ("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), make_predicate(add_checker)) ), elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None] diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index a6be6f2e26ce..e3ff42e4f72b 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -192,6 +192,7 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, te if use_dnnl: processed_mod = partition_for_dnnl(processed_mod, params, alter_layout) check_dnnl_used(processed_mod) + print(processed_mod) with tvm.transform.PassContext(opt_level=3): func = relay.create_executor( mode, mod=processed_mod, device=dev, target=target @@ -791,7 +792,8 @@ def test_conv2d_pattern(run_module, dtype="float32"): def test_conv2d_bias_sum_relu(run_module, dtype="float32"): x_shape = (1, 32, 8, 8) k_shape = (16, 32, 3, 3) - def get_conv2d_bn_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"): + + def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"): out, dic, param_lst = get_conv2d_bias(x_shape=x_shape, k_shape=k_shape, dtype=dtype) beta = relay.const(np.zeros(k_shape[0]).astype(dtype)) gamma = relay.const(np.ones(k_shape[0]).astype(dtype)) @@ -808,13 +810,18 @@ def get_conv2d_bn_sum_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype= scale=True, epsilon=1e-5, ) - sum_data = relay.var("data1", shape=(1, 16, 6, 6), dtype=dtype) + sum_data = relay.var("data1", shape=sum_shape, dtype=dtype) out = relay.add(out, sum_data) - dic["data1"] = (1, 16, 6, 6) + dic["data1"] = sum_shape param_lst += ["data1"] return relay.nn.relu(out), dic, param_lst - conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, dtype=dtype) + conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype) + conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu) + config = conv2d_bn_sum_relu, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + + conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype) conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu) config = conv2d_bn_sum_relu, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) @@ -1764,4 +1771,5 @@ def generate_model(p, c): if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_conv2d_bias_sum_relu(True) \ No newline at end of file From d89db44db03ba7a510f7197de5367aa45bf7ebfa Mon Sep 17 00:00:00 2001 From: Ivy Date: Wed, 27 Jul 2022 09:58:54 +0800 Subject: [PATCH 4/5] fix error in test_pass_partition_graph --- tests/python/relay/test_pass_partition_graph.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 4b7ac92136e9..0c702d074eaa 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -919,7 +919,11 @@ def expected(): def test_dnnl_fuse(): dnnl_patterns = get_pattern_table("dnnl") - dnnl_pat_dic = dict(dnnl_patterns) + valid_pats = list() + for pattern in dnnl_patterns: + if len(pattern) == 2: + valid_pats.append(pattern) + dnnl_pat_dic = dict(valid_pats) ( conv2d_bias_relu_pat, conv2d_bias_sigmoid_pat, From 001898f47e16cf103f3a9238313bc384812aaa88 Mon Sep 17 00:00:00 2001 From: Ivy Date: Thu, 28 Jul 2022 16:22:57 +0800 Subject: [PATCH 5/5] fix lint error --- python/tvm/relay/op/contrib/dnnl.py | 16 ++++++++++++---- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 2 +- tests/python/contrib/test_dnnl.py | 12 +++++++----- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index fbab9ab745f9..129b3de31ae4 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -396,11 +396,19 @@ def pattern_table(): dnnl_patterns.append(make_qnn_conv2d_pattern()) dnnl_patterns.append(make_qnn_dense_pattern()) dnnl_patterns.append( - ("dnnl.conv2d_bias_sum_relu", make_conv_bias_sum_relu_pattern("nn.conv2d"), make_predicate(add_checker)) - ), + ( + "dnnl.conv2d_bias_sum_relu", + make_conv_bias_sum_relu_pattern("nn.conv2d"), + make_predicate(add_checker), + ) + ) dnnl_patterns.append( - ("dnnl.conv2d_bias_sum", make_conv_bias_sum_relu_pattern("nn.conv2d", False), make_predicate(add_checker)) - ), + ( + "dnnl.conv2d_bias_sum", + make_conv_bias_sum_relu_pattern("nn.conv2d", False), + make_predicate(add_checker), + ) + ) elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None] for with_bias in [True, False]: diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 6627fc82a5a8..dde415829d49 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -368,7 +368,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // TODO(@apeskov): Simulation of inplace primitive. just as PoC. auto sum_in_tr = GetInputByName(nid, "sum_idx").TreatAs(dst_layout); if (op_name.find("_sum") != std::string::npos) { - sum_in_tr = GetInput(nid, node.GetInputs().size()-1); + sum_in_tr = GetInput(nid, node.GetInputs().size() - 1); sum_in_tr = sum_in_tr.TreatAs(dst_layout); } diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index e3ff42e4f72b..f12ee7479b85 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -192,7 +192,6 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, te if use_dnnl: processed_mod = partition_for_dnnl(processed_mod, params, alter_layout) check_dnnl_used(processed_mod) - print(processed_mod) with tvm.transform.PassContext(opt_level=3): func = relay.create_executor( mode, mod=processed_mod, device=dev, target=target @@ -816,12 +815,16 @@ def get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape, dtype="float32"): param_lst += ["data1"] return relay.nn.relu(out), dic, param_lst - conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype) + conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu( + x_shape, k_shape, sum_shape=(1, 16, 6, 6), dtype=dtype + ) conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu) config = conv2d_bn_sum_relu, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) - conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu(x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype) + conv2d_bn_sum_relu, dic, param_lst = get_conv2d_bn_sum_relu( + x_shape, k_shape, sum_shape=(1, 16, 1, 1), dtype=dtype + ) conv2d_bn_sum_relu = tvm.IRModule.from_expr(conv2d_bn_sum_relu) config = conv2d_bn_sum_relu, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) @@ -1771,5 +1774,4 @@ def generate_model(p, c): if __name__ == "__main__": - # tvm.testing.main() - test_conv2d_bias_sum_relu(True) \ No newline at end of file + tvm.testing.main()