From d72c6a80016fa174686acccfa33541dd9a3db12f Mon Sep 17 00:00:00 2001
From: Ivy
Date: Mon, 20 Jun 2022 15:48:16 +0800
Subject: [PATCH 1/6] support post-op swish

---
 python/tvm/relay/op/contrib/dnnl.py           | 13 ++++++++++---
 src/relay/backend/contrib/dnnl/codegen.cc     |  1 +
 src/relay/backend/utils.h                     |  6 +++---
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc |  6 +++++-
 tests/python/contrib/test_dnnl.py             |  8 +++++++-
 5 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index c251b66bfbc7..2dbb08dea358 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -128,8 +128,11 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
         conv_out = is_op("add")(conv, bias)
     else:
         conv_out = conv
-    if with_eltwise:
-        return is_op(with_eltwise)(conv_out)
+    if with_eltwise == "swish":
+        sig_out = is_op("sigmoid")(conv_out)
+        conv_out = is_op("multiply")(conv_out, sig_out)
+    elif with_eltwise:
+        conv_out = is_op(with_eltwise)(conv_out)
     return conv_out
 
 
@@ -165,6 +168,9 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
         added_erf_val = is_op("add")(erf_val, const2)
         mul_val = is_op("multiply")(dense_out, added_erf_val)
         dense_out = is_op("multiply")(mul_val, const3)
+    elif with_eltwise == "swish":
+        sig_out = is_op("sigmoid")(dense_out)
+        dense_out = is_op("multiply")(dense_out, sig_out)
     elif with_eltwise:
         dense_out = is_op(with_eltwise)(dense_out)
     return dense_out
@@ -191,6 +197,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
         pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
     pat_name += "_bias" if with_bias else ""
     pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
+    pat_name =pat_name.replace("_swish", "_sigmoid_mul")
     if "conv" in op_name:
         dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
     elif op_name == "nn.dense":
@@ -282,7 +289,7 @@ def pattern_table():
         dnnl_patterns.append(make_qnn_conv2d_pattern())
         dnnl_patterns.append(make_qnn_dense_pattern())
 
-    elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", None]
+    elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", "swish", None]
     for with_bias in [True, False]:
         for elt in elt_list:
             if not with_bias and not elt:
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 2f47c23a7cf9..c800329c912e 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -470,6 +470,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
       {"relu", "nn.relu"},
       {"tanh", "tanh"},
       {"sigmoid", "sigmoid"},
+      {"mul", "multiply"},
       {"nn.deconv2d", "nn.conv2d_transpose"},
       {"nn.deconv3d", "nn.conv3d_transpose"},
   };
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index d6fae8c72b5e..1a8d488b5fc1 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -466,9 +466,9 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
   ICHECK_GT(current_call->args.size(), 0);
   size_t valid_node_idx = 0;
-  while (valid_node_idx < current_call->args.size() &&
-         current_call->args[valid_node_idx].as<ConstantNode>()) {
-    valid_node_idx++;
+  while (valid_node_idx < current_call->args.size()) {
+    if (IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1])) break;
+    valid_node_idx++;
   }
   const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
   return GetRootCall(next_call, depth - 1, expected_op_names);
 }
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index a46f170fea94..71d4d184c007 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -200,7 +200,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
     }
     if (std::regex_match(op_name, sigmoid_pat)) {
-      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+      if (op_name.find("_sigmoid_mul") != std::string::npos) {
+        ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
+      } else {
+        ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
+      }
     }
     if (std::regex_match(op_name, gelu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 2138eda08697..84f1c5ebb710 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -192,7 +192,6 @@ def run_and_verify(mod, input, params, target, run_module, subgraph_num=None, te
     if use_dnnl:
         processed_mod = partition_for_dnnl(processed_mod, params, alter_layout)
         check_dnnl_used(processed_mod)
-
     with tvm.transform.PassContext(opt_level=3):
         func = relay.create_executor(
             mode, mod=processed_mod, device=dev, target=target
@@ -819,6 +818,13 @@ def test_conv2d_pattern(run_module, dtype="float32"):
     config = conv2d_bias_bn_relu, dic, param_lst
     run_and_verify_func(config, run_module=run_module, dtype=dtype)
 
+    conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, activation=None, dtype=dtype)
+    conv2d_bias_sig = relay.sigmoid(conv2d_bias)
+    conv2d_bias_swish = relay.multiply(conv2d_bias, conv2d_bias_sig)
+    conv2d_bias_swish = tvm.IRModule.from_expr(conv2d_bias_swish)
+    config = conv2d_bias_swish, dic, param_lst
+    run_and_verify_func(config, run_module=run_module, dtype=dtype)
+
 
 def test_conv2d_transpose(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)

From dbb01b8ada8be10fbef3dddc309afddab13f8ea0 Mon Sep 17 00:00:00 2001
From: Ivy
Date: Mon, 20 Jun 2022 16:08:20 +0800
Subject: [PATCH 2/6] support post-op clip

---
 python/tvm/relay/op/contrib/dnnl.py           |   2 +-
 src/relay/backend/contrib/dnnl/codegen.cc     |   8 ++
 src/runtime/contrib/dnnl/dnnl_json_runtime.cc |   6 +
 tests/python/contrib/test_dnnl.py             | 132 ++++--------------
 4 files changed, 46 insertions(+), 102 deletions(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 2dbb08dea358..ccb02ceda002 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -289,7 +289,7 @@ def pattern_table():
         dnnl_patterns.append(make_qnn_conv2d_pattern())
         dnnl_patterns.append(make_qnn_dense_pattern())
 
-    elt_list = ["nn.relu", "tanh", "sigmoid", "gelu", "swish", None]
+    elt_list = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
     for with_bias in [True, False]:
         for elt in elt_list:
             if not with_bias and not elt:
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index c800329c912e..4abfc9d9b136 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -470,6 +470,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
       {"relu", "nn.relu"},
       {"tanh", "tanh"},
       {"sigmoid", "sigmoid"},
+      {"clip", "clip"},
       {"mul", "multiply"},
       {"nn.deconv2d", "nn.conv2d_transpose"},
       {"nn.deconv3d", "nn.conv3d_transpose"},
   };
@@ -567,6 +568,13 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
                                       "kernel", /* op_type_ */
                                       inputs, 1 /* num_outputs_ */);
     SetCallNodeAttribute(node, call);
+    // If has post-op `clip`. Assume the last op is clip, add clip's attrs to the pattern attrs.
+    if (name.find("_clip") != std::string::npos) {
+      auto clip_call = cn->op.as<FunctionNode>()->body.as<CallNode>();
+      ICHECK(IsOp(clip_call, "clip"));
+      SetCallNodeAttribute(node, clip_call);
+    }
+
     // For QNN.
     for (const auto& kvp : extra_attrs) node->SetAttr(kvp.first, kvp.second);
 
     return AddNode(node, GetRef<Expr>(cn));
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 71d4d184c007..6c0fd64066e5 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -189,6 +189,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::regex relu_pat(".*_relu.*");
     std::regex tanh_pat(".*_tanh.*");
     std::regex sigmoid_pat(".*_sigmoid.*");
+    std::regex clip_pat(".*_clip.*");
    std::regex gelu_pat(".*_gelu.*");
 
     // Parsing post-ops.
@@ -200,6 +201,11 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (std::regex_match(op_name, tanh_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
     }
+    if (std::regex_match(op_name, clip_pat)) {
+      float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
+      float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
+      ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
+    }
     if (std::regex_match(op_name, sigmoid_pat)) {
       if (op_name.find("_sigmoid_mul") != std::string::npos) {
         ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py
index 84f1c5ebb710..5da1101be74e 100755
--- a/tests/python/contrib/test_dnnl.py
+++ b/tests/python/contrib/test_dnnl.py
@@ -236,6 +236,23 @@ def run_and_verify_func(
     )
 
 
+def add_activation(activation, out, dic, param_lst):
+    if activation == "relu":
+        return relay.nn.relu(out), dic, param_lst
+    elif activation == "tanh":
+        return relay.tanh(out), dic, param_lst
+    elif activation == "sigmoid":
+        return relay.sigmoid(out), dic, param_lst
+    elif activation == "clip":
+        return relay.clip(out, 0.0, 6.0), dic, param_lst
+    elif activation == "swish":
+        sig_out = relay.sigmoid(out)
+        out = relay.multiply(out, sig_out)
+        return out, dic, param_lst
+    else:
+        return out, dic, param_lst
+
+
 def get_conv1d(
     x_shape=((1, 3, 224)),
     k_shape=(16, 3, 3),
@@ -261,15 +278,7 @@ def get_conv1d(
     )
     dic = {"x": x_shape, "kernel": k_shape}
     param_lst = ["kernel"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dtype="float32"):
@@ -278,15 +287,7 @@ def get_conv1d_bias(x_shape=(1, 3, 224), k_shape=(10, 3, 3), activation=None, dt
     out = relay.nn.bias_add(conv, bias)
     dic["bias"] = (k_shape[0],)
     param_lst += ["bias"]
-
-    if activation == "relu":
-        return relay.nn.relu(out), dic, param_lst
-    elif activation == "tanh":
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
dtype="float32"): @@ -333,15 +334,7 @@ def get_conv2d( ) dic = {"x": x_shape, "kernel": k_shape} param_lst = ["kernel"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst + return add_activation(activation, out, dic, param_lst) def get_conv2d_transpose( @@ -366,15 +359,7 @@ def get_conv2d_transpose( ) dic = {"x": x_shape, "kernel": k_shape} param_lst = ["kernel"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst + return add_activation(activation, out, dic, param_lst) def get_conv2d_weights_const( @@ -411,15 +396,7 @@ def get_conv2d_bias( out = relay.nn.bias_add(conv, bias) dic["bias"] = (k_shape[0],) param_lst += ["bias"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst + return add_activation(activation, out, dic, param_lst) def get_conv2d_transpose_bias( @@ -430,15 +407,7 @@ def get_conv2d_transpose_bias( out = relay.nn.bias_add(conv, bias) dic["bias"] = (k_shape[1],) param_lst += ["bias"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst + return add_activation(activation, out, dic, param_lst) def get_conv2d_bias_bn_relu(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"): @@ -502,15 +471,7 @@ def get_conv3d( ) dic = {"x": x_shape, "kernel": k_shape} param_lst = ["kernel"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst + return add_activation(activation, out, dic, param_lst) def get_conv3d_transpose( @@ -541,15 +502,7 @@ def get_conv3d_transpose( ) dic = {"x": x_shape, "kernel": k_shape} param_lst = ["kernel"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst + return add_activation(activation, out, dic, param_lst) def get_conv3d_bias( @@ -560,15 +513,7 @@ def get_conv3d_bias( out = relay.nn.bias_add(conv, bias) dic["bias"] = (k_shape[0],) param_lst += ["bias"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), dic, param_lst - elif activation == "sigmoid": - return relay.sigmoid(out), dic, param_lst - else: - return out, dic, param_lst + return add_activation(activation, out, dic, param_lst) def get_conv3d_transpose_bias( @@ -579,15 +524,7 @@ def get_conv3d_transpose_bias( out = relay.nn.bias_add(conv, bias) dic["bias"] = (k_shape[1],) param_lst += ["bias"] - - if activation == "relu": - return relay.nn.relu(out), dic, param_lst - elif activation == "tanh": - return relay.tanh(out), 
-        return relay.tanh(out), dic, param_lst
-    elif activation == "sigmoid":
-        return relay.sigmoid(out), dic, param_lst
-    else:
-        return out, dic, param_lst
+    return add_activation(activation, out, dic, param_lst)
 
 
 def gelu_helper(data):
@@ -796,7 +733,7 @@ def test_conv2d_weights_const(run_module, dtype="float32"):
 def test_conv2d_pattern(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)
     k_shape = (16, 32, 3, 3)
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv2d, dic, param_lst = get_conv2d(x_shape, k_shape, activation=a, dtype=dtype)
         conv2d = tvm.IRModule.from_expr(conv2d)
@@ -818,13 +755,6 @@ def test_conv2d_pattern(run_module, dtype="float32"):
     config = conv2d_bias_bn_relu, dic, param_lst
     run_and_verify_func(config, run_module=run_module, dtype=dtype)
 
-    conv2d_bias, dic, param_lst = get_conv2d_bias(x_shape, k_shape, activation=None, dtype=dtype)
-    conv2d_bias_sig = relay.sigmoid(conv2d_bias)
-    conv2d_bias_swish = relay.multiply(conv2d_bias, conv2d_bias_sig)
-    conv2d_bias_swish = tvm.IRModule.from_expr(conv2d_bias_swish)
-    config = conv2d_bias_swish, dic, param_lst
-    run_and_verify_func(config, run_module=run_module, dtype=dtype)
-
 
 def test_conv2d_transpose(run_module, dtype="float32"):
     x_shape = (1, 32, 8, 8)
@@ -845,7 +775,7 @@ def test_conv2d_transpose(run_module, dtype="float32"):
 
 
 def test_conv2d_transpose_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv2d, dic, param_lst = get_conv2d_transpose(activation=a, dtype=dtype)
         conv2d = tvm.IRModule.from_expr(conv2d)
@@ -878,7 +808,7 @@ def test_conv3d(run_module, dtype="float32"):
 
 
 def test_conv3d_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv3d, dic, param_lst = get_conv3d(activation=a, dtype=dtype)
         conv3d = tvm.IRModule.from_expr(conv3d)
@@ -911,7 +841,7 @@ def test_conv3d_transpose(run_module, dtype="float32"):
 
 
 def test_conv3d_transpose_pattern(run_module, dtype="float32"):
-    activation_lst = [None, "relu", "tanh", "sigmoid"]
+    activation_lst = [None, "relu", "tanh", "sigmoid", "clip", "swish"]
     for a in activation_lst:
         conv3d, dic, param_lst = get_conv3d_transpose(activation=a, dtype=dtype)
         conv3d = tvm.IRModule.from_expr(conv3d)

From 04babf2091854e66a238b077cf68d6d007cfe3b9 Mon Sep 17 00:00:00 2001
From: Ivy
Date: Mon, 20 Jun 2022 14:13:24 +0800
Subject: [PATCH 3/6] enhance get_shape and get_dtype in dnnl.py to support efficientnet

---
 python/tvm/relay/op/contrib/dnnl.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index ccb02ceda002..62b369ae4dbe 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -387,6 +387,8 @@ def get_shape(tensor):
     if isinstance(tensor, tvm.ir.container.Array):
         return tensor[-1].shape
     if isinstance(tensor, relay.expr.Call):
+        if tensor.op.name=="multiply":
+            return tensor.type_args[0].shape
         return tensor.checked_type.shape
     raise TypeError("Unsupport data type: %s" % type(tensor))
 
@@ -402,6 +404,8 @@ def get_dtype(tensor):
     if isinstance(tensor, tvm.ir.container.Array):
         return tensor[-1].dtype
     if isinstance(tensor, relay.expr.Call):
+        if tensor.op.name=="multiply":
+            return tensor.type_args[0].dtype
         return tensor.checked_type.dtype
     raise TypeError("Unsupport data type: %s" % type(tensor))

From 8746936905d16cbf7317ea92a955d7a43ddae33c Mon Sep 17 00:00:00 2001
From: Ivy
Date: Wed, 22 Jun 2022 14:00:35 +0800
Subject: [PATCH 4/6] add checks for with_eltwise whether in supported list

---
 python/tvm/relay/op/contrib/dnnl.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 62b369ae4dbe..32cb1b79fbc6 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -51,6 +51,7 @@
 logger = logging.getLogger("DNNL")
 
 
+supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", None]
 
 
 def _register_external_op_helper(op_name, supported=True):
@@ -120,6 +121,8 @@ def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
     conv_out : CallPattern
         Call node sequence.
     """
+    if with_eltwise not in supported_post_elts:
+        raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
     data = wildcard()
     weight = wildcard()
     bias = wildcard()
@@ -150,6 +153,8 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
     dense_out : CallPattern
         Call node sequence.
     """
+    if with_eltwise not in supported_post_elts:
+        raise ValueError("Unsupported eltwise post-op: %s" % with_eltwise)
     data = wildcard()
     weight = wildcard()
     bias = wildcard()

From 456eb0c6caf68162e18dc462fea311271babf74c Mon Sep 17 00:00:00 2001
From: Ivy
Date: Mon, 27 Jun 2022 09:53:19 +0800
Subject: [PATCH 5/6] fix lint

---
 python/tvm/relay/op/contrib/dnnl.py | 6 +++---
 src/relay/backend/utils.h           | 5 +++--
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 32cb1b79fbc6..f3d53a1df609 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -202,7 +202,7 @@ def make_dnnl_pattern(op_name, with_bias, with_eltwise):
         pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
     pat_name += "_bias" if with_bias else ""
     pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
-    pat_name =pat_name.replace("_swish", "_sigmoid_mul")
+    pat_name = pat_name.replace("_swish", "_sigmoid_mul")
     if "conv" in op_name:
         dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
     elif op_name == "nn.dense":
@@ -392,7 +392,7 @@ def get_shape(tensor):
     if isinstance(tensor, tvm.ir.container.Array):
         return tensor[-1].shape
     if isinstance(tensor, relay.expr.Call):
-        if tensor.op.name=="multiply":
+        if tensor.op.name == "multiply":
             return tensor.type_args[0].shape
         return tensor.checked_type.shape
     raise TypeError("Unsupport data type: %s" % type(tensor))
@@ -409,7 +409,7 @@ def get_dtype(tensor):
     if isinstance(tensor, tvm.ir.container.Array):
         return tensor[-1].dtype
     if isinstance(tensor, relay.expr.Call):
-        if tensor.op.name=="multiply":
+        if tensor.op.name == "multiply":
             return tensor.type_args[0].dtype
         return tensor.checked_type.dtype
     raise TypeError("Unsupport data type: %s" % type(tensor))
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index 1a8d488b5fc1..f7c10c2cfee7 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -467,8 +467,9 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
   ICHECK_GT(current_call->args.size(), 0);
   size_t valid_node_idx = 0;
   while (valid_node_idx < current_call->args.size()) {
-    if (IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1])) break;
-    valid_node_idx++;
+    if (IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1]))
+      break;
+    valid_node_idx++;
   }
   const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
   return GetRootCall(next_call, depth - 1, expected_op_names);

From f35ad2d14e0f6c7f6fd7d36997484ca45c137012 Mon Sep 17 00:00:00 2001
From: Ivy
Date: Mon, 4 Jul 2022 14:58:32 +0800
Subject: [PATCH 6/6] fix test

---
 src/relay/backend/utils.h                     |  9 ++++---
 .../python/relay/test_pass_partition_graph.py | 26 +++++++++++++++----
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index f7c10c2cfee7..57c066131181 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -466,9 +466,12 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
   ICHECK_GT(current_call->args.size(), 0);
   size_t valid_node_idx = 0;
-  while (valid_node_idx < current_call->args.size()) {
-    if (IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1]))
-      break;
-    valid_node_idx++;
+  while (valid_node_idx < current_call->args.size() &&
+         current_call->args[valid_node_idx].as<ConstantNode>()) {
+    valid_node_idx++;
+  }
+  while (valid_node_idx < current_call->args.size() &&
+         !(IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1]))) {
+    valid_node_idx++;
   }
   const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
   return GetRootCall(next_call, depth - 1, expected_op_names);
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 58b41189a0f0..4b7ac92136e9 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -919,6 +919,7 @@ def expected():
 def test_dnnl_fuse():
     dnnl_patterns = get_pattern_table("dnnl")
+    dnnl_pat_dic = dict(dnnl_patterns)
     (
         conv2d_bias_relu_pat,
         conv2d_bias_sigmoid_pat,
@@ -926,11 +927,26 @@ def test_dnnl_fuse():
         conv2d_relu_pat,
         conv2d_sigmoid_pat,
     ) = (
-        dnnl_patterns[3],
-        dnnl_patterns[15],
-        dnnl_patterns[22],
-        dnnl_patterns[28],
-        dnnl_patterns[40],
+        (
+            "dnnl.conv2d_bias_relu",
+            dnnl_pat_dic["dnnl.conv2d_bias_relu"],
+        ),
+        (
+            "dnnl.conv2d_bias_sigmoid",
+            dnnl_pat_dic["dnnl.conv2d_bias_sigmoid"],
+        ),
+        (
+            "dnnl.conv2d_bias",
+            dnnl_pat_dic["dnnl.conv2d_bias"],
+        ),
+        (
+            "dnnl.conv2d_relu",
+            dnnl_pat_dic["dnnl.conv2d_relu"],
+        ),
+        (
+            "dnnl.conv2d_sigmoid",
+            dnnl_pat_dic["dnnl.conv2d_sigmoid"],
+        ),
     )
 
     def get_blocks(
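
For reference, the sketch below is not part of the patch series above (which is truncated at get_blocks). It illustrates how the new post-op fusions are expected to be exercised from user code, mirroring the cases added to tests/python/contrib/test_dnnl.py: a conv2d + bias graph followed by clip, or by the sigmoid/multiply swish idiom, should be rewritten by partition_for_dnnl into a single DNNL composite function ("dnnl.conv2d_bias_clip" and "dnnl.conv2d_bias_sigmoid_mul" respectively). It is a minimal sketch assuming a TVM build with the DNNL (oneDNN) runtime enabled; the helper name, shapes, and clip bounds are illustrative only.

# Minimal usage sketch (not part of the patches above). Assumes a TVM build
# with USE_DNNL enabled; shapes and clip bounds are illustrative.
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl


def conv2d_bias(x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), dtype="float32"):
    # conv2d + bias_add base graph, same shapes as the test helpers use
    x = relay.var("x", shape=x_shape, dtype=dtype)
    kernel = relay.var("kernel", shape=k_shape, dtype=dtype)
    bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
    conv = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:], padding=(1, 1))
    return relay.nn.bias_add(conv, bias)


# conv2d + bias + clip: expected to match the "dnnl.conv2d_bias_clip" pattern.
clip_expr = relay.clip(conv2d_bias(), 0.0, 6.0)

# conv2d + bias followed by sigmoid/multiply (the swish idiom): expected to
# match the "dnnl.conv2d_bias_sigmoid_mul" pattern added in this series.
cb = conv2d_bias()
swish_expr = relay.multiply(cb, relay.sigmoid(cb))

for expr in (clip_expr, swish_expr):
    mod = tvm.IRModule.from_expr(expr)
    mod = partition_for_dnnl(mod)
    print(mod)  # fused post-ops should appear as single Composite functions offloaded to dnnl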